{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "82acAhWYGIPx"
      },
      "source": [
        "# Angular Steering\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j7hOtw7UHXdD"
      },
      "source": [
        "This notebook contains:\n",
        "\n",
        "- The process of extracting the refusal direction and constructing the steering plane.\n",
        "- Visualization of the activation, extracted directions and constructed steering planes.\n",
        "- The creation of the steering config that can be used with our fork of vLLM.\n",
        "\n",
        "This notebook is inspired by https://colab.research.google.com/drive/1a-aQvKC9avdZpdyBn4jgRQFObTPy1JZw\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fcxHyDZw6b86"
      },
      "source": [
        "## Setup\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eSAh5Q3mXxnK"
      },
      "source": [
        "### Dependencies\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dLeei4-T6Wef"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama nbformat plotly datasets pandas scikit-learn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_vhhwl-2-jPg"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import functools\n",
        "import einops\n",
        "import requests\n",
        "import pandas as pd\n",
        "import io\n",
        "import textwrap\n",
        "import gc\n",
        "import numpy as np\n",
        "import plotly\n",
        "import os\n",
        "\n",
        "\n",
        "from pathlib import Path\n",
        "from datasets import load_dataset\n",
        "from sklearn.model_selection import train_test_split\n",
        "from tqdm import tqdm\n",
        "from torch import Tensor\n",
        "from typing import List\n",
        "from transformer_lens import HookedTransformer, utils, ActivationCache\n",
        "from transformer_lens.hook_points import HookPoint\n",
        "from transformers import AutoTokenizer\n",
        "from jaxtyping import Float, Int\n",
        "from colorama import Fore\n",
        "import plotly.graph_objects as go\n",
        "import plotly.express as px"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6ZOoJagxD49V"
      },
      "source": [
        "### Models and configs\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1M8QkxefXxnN"
      },
      "source": [
        "For Gemma 2 Instruct Models, the lm_head module is not there and is generally tied with embed_tokens.weight. The final lm_head should be a transpose of it.\n",
        "\n",
        "In HookedTransformers, the embedding is characterized by embed.W_E, and since we have, from the documentation:\n",
        "\n",
        "\"Gemma Models scale embeddings by multiplying by sqrt(d_model), use hidden state type to match HF implementation\"\n",
        "\n",
        "embed.W_E = embed_tokens.weight * sqrt(D) <- Transformed original embedding\n",
        "\n",
        "unmebed.W_U = embed_tokens.weight.T or lm_head.T\n",
        "\n",
        "Note the bfloat/float16/float32 computation required\n",
        "fl(fl(embed_tokens.weight) * fl(sqrt(fl(D))))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zjsyzD5VXxnN"
      },
      "outputs": [],
      "source": [
        "from safetensors.torch import load_file\n",
        "\n",
        "# This function follows the weight conversions in gemma.py from:\n",
        "# https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/pretrained/weight_conversions/gemma.py\n",
        "def hf_state_dict_to_HT_state_dict_gemma(hf_state_dict: dict,\n",
        "                                         d_model = 3584,\n",
        "                                         n_layers = 42,\n",
        "                                         n_heads = 16,\n",
        "                                         n_key_value_heads = 8,\n",
        "                                         d_head = 256,\n",
        "                                         d_mlp = 14336,\n",
        "                                         d_vocab = 256000,\n",
        "                                         dtype = torch.bfloat16):\n",
        "    # Initialize State Dict\n",
        "    gemma_ht_state_dict = {}\n",
        "\n",
        "    #  Load the Embeddings\n",
        "    gemma_ht_state_dict[\"embed.W_E\"] = hf_state_dict[\"embed_tokens.weight\"] * torch.tensor(d_model ** 0.5, dtype = dtype)\n",
        "\n",
        "    #  Load the layer weights\n",
        "    for l in range(n_layers):\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.ln1.w\"] = hf_state_dict[f\"layers.{l}.input_layernorm.weight\"].float() \\\n",
        "            + torch.ones_like(hf_state_dict[f\"layers.{l}.input_layernorm.weight\"], dtype= torch.float32)\n",
        "\n",
        "        #  Assume that there's post normalization\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.ln1_post.w\"] = hf_state_dict[f\"layers.{l}.post_attention_layernorm.weight\"].float() \\\n",
        "            + torch.ones_like(hf_state_dict[f\"layers.{l}.post_attention_layernorm.weight\"], dtype= torch.float32)\n",
        "\n",
        "        #  Attention Weights\n",
        "        W_Q = hf_state_dict[f\"layers.{l}.self_attn.q_proj.weight\"]\n",
        "        W_K = hf_state_dict[f\"layers.{l}.self_attn.k_proj.weight\"]\n",
        "        W_V = hf_state_dict[f\"layers.{l}.self_attn.v_proj.weight\"]\n",
        "        W_Q = einops.rearrange(W_Q, \"(n h) m->n m h\", n=n_heads)\n",
        "        W_K = einops.rearrange(W_K, \"(n h) m->n m h\", n=n_key_value_heads)\n",
        "        W_V = einops.rearrange(W_V, \"(n h) m->n m h\", n=n_key_value_heads)\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn.W_Q\"] = W_Q\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn._W_K\"] = W_K\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn._W_V\"] = W_V\n",
        "\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn.b_Q\"] = torch.zeros(n_heads, d_head, dtype=dtype)\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn._b_K\"] = torch.zeros(n_key_value_heads, d_head, dtype=dtype)\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn._b_V\"] = torch.zeros(n_key_value_heads, d_head, dtype=dtype)\n",
        "\n",
        "        W_O = hf_state_dict[f\"layers.{l}.self_attn.o_proj.weight\"]\n",
        "        W_O = einops.rearrange(W_O, \"m (n h)->n h m\", n=n_heads)\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn.W_O\"] = W_O\n",
        "\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.attn.b_O\"] = torch.zeros(d_model, dtype=dtype)\n",
        "\n",
        "        #  Layer Norm 2\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.ln2.w\"] = hf_state_dict[f\"layers.{l}.pre_feedforward_layernorm.weight\"].float() \\\n",
        "            + torch.ones_like(hf_state_dict[f\"layers.{l}.pre_feedforward_layernorm.weight\"], dtype = torch.float32)\n",
        "\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.ln2_post.w\"] = hf_state_dict[f\"layers.{l}.post_feedforward_layernorm.weight\"].float() \\\n",
        "            + torch.ones_like(hf_state_dict[f\"layers.{l}.post_feedforward_layernorm.weight\"], dtype = torch.float32)\n",
        "\n",
        "        #  MLP Layer\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.mlp.W_in\"] = hf_state_dict[f\"layers.{l}.mlp.up_proj.weight\"].T\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.mlp.W_gate\"] = hf_state_dict[f\"layers.{l}.mlp.gate_proj.weight\"].T\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.mlp.b_in\"] = torch.zeros(d_mlp, dtype=dtype)\n",
        "\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.mlp.W_out\"] = hf_state_dict[f\"layers.{l}.mlp.down_proj.weight\"].T\n",
        "        gemma_ht_state_dict[f\"blocks.{l}.mlp.b_out\"] = torch.zeros(d_model, dtype=dtype)\n",
        "\n",
        "    gemma_ht_state_dict[\"ln_final.w\"] = hf_state_dict[\"norm.weight\"].float() \\\n",
        "        + torch.ones_like(hf_state_dict[\"norm.weight\"], dtype=torch.float32)\n",
        "\n",
        "    #  Note that W_U = embed_tokens.T\n",
        "    gemma_ht_state_dict[\"unembed.W_U\"] = hf_state_dict[\"embed_tokens.weight\"].T\n",
        "    gemma_ht_state_dict[\"unembed.b_U\"] = torch.zeros(d_vocab, dtype=dtype)\n",
        "\n",
        "    return gemma_ht_state_dict\n",
        "\n",
        "def custom_model_loader(MODEL_PATH, DEVICE, MODEL_CACHE_DIR):\n",
        "    if MODEL_PATH == \"Unispac/Gemma-2-9B-IT-With-Deeper-Safety-Alignment\":\n",
        "\n",
        "        # Load Gemma 2 9b IT\n",
        "        model = HookedTransformer.from_pretrained_no_processing(\n",
        "                    \"google/gemma-2-9b-it\",\n",
        "                    device=DEVICE,\n",
        "                    dtype=torch.bfloat16,\n",
        "                    default_padding_side=\"left\",\n",
        "                    cache_dir = MODEL_CACHE_DIR,\n",
        "                    # bf16=True\n",
        "                )\n",
        "\n",
        "        # For Light Sanity Check:\n",
        "        to_check_1 = model.state_dict()['blocks.2.attn.W_Q'].clone().detach().cpu()\n",
        "\n",
        "        #  Obtain the path of the shards\n",
        "        shards_path = MODEL_CACHE_DIR / \"models--Unispac--Gemma-2-9B-IT-With-Deeper-Safety-Alignment\" / \"snapshots\" / \"731df7c670eb136b24fad4960b6b87361e242bd8\"\n",
        "        shards = shards_path.glob(\"*.safetensors\")\n",
        "\n",
        "        #  Load the tensors\n",
        "        state_dict = {}\n",
        "\n",
        "        # print(len(shards))\n",
        "        for shard in shards:\n",
        "            print(shard)\n",
        "            shard_comp = load_file(shard)\n",
        "            state_dict.update(shard_comp)\n",
        "\n",
        "        #  Remove the \"model.\" in state dict names\n",
        "        old_keys = list(state_dict.keys())\n",
        "        for key in old_keys:\n",
        "            new_key = key.replace(\"model.\", \"\")\n",
        "            state_dict[new_key] = state_dict.pop(key)\n",
        "\n",
        "\n",
        "        #  Transform New State Dict into HT Gemma Version\n",
        "        gemma_ht_state_dict = hf_state_dict_to_HT_state_dict_gemma(state_dict)\n",
        "        assert len(gemma_ht_state_dict) == 718, f\"There is an issue with Transformation function, there are currently only {len(gemma_ht_state_dict)} items in the dictionary!\"\n",
        "\n",
        "        #  Load Weights Manually to the Model\n",
        "        print(\"Loading New Weights into the HT Model\")\n",
        "        model.load_and_process_state_dict(gemma_ht_state_dict,\n",
        "                                          fold_ln = False,\n",
        "                                          center_writing_weights = False,\n",
        "                                          center_unembed = False,\n",
        "                                          refactor_factored_attn_matrices = False,\n",
        "                                          fold_value_biases = False)\n",
        "        print(\"New Weights Loaded Successfully\")\n",
        "\n",
        "        to_check_2 = model.state_dict()['blocks.2.attn.W_Q'].clone().detach().cpu()\n",
        "        assert not torch.allclose(to_check_1, to_check_2), \"Loaded Prior and After have similar weights, this suggests an issue!\"\n",
        "        print(\"New Weights are different than old weights!\")\n",
        "\n",
        "        del gemma_ht_state_dict, state_dict, to_check_1, to_check_2\n",
        "        gc.collect()\n",
        "\n",
        "        return model\n",
        "\n",
        "    else:\n",
        "        model = HookedTransformer.from_pretrained_no_processing(\n",
        "                    MODEL_PATH,\n",
        "                    device=DEVICE,\n",
        "                    dtype=torch.bfloat16,\n",
        "                    default_padding_side=\"left\",\n",
        "                    cache_dir = MODEL_CACHE_DIR,\n",
        "                    # bf16=True\n",
        "                )\n",
        "\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h6OZRFYcXxnN"
      },
      "outputs": [],
      "source": [
        "## Unit Test Block:\n",
        "## We load HookedTransformer Gemma 9B Instruct as a base model, and we store the following:\n",
        "## Initial W_E, Final W_U; and randomly choose layer 6, all useful modules;\n",
        "## Next, load a new model using AutoTransformer and store the state dict\n",
        "## Transform the state dict using the new function and load it into the HT version\n",
        "## Extract the weights in the same position and we do a test for whether they are the same\n",
        "from transformers import AutoModel\n",
        "\n",
        "run_unit_test = False\n",
        "\n",
        "def unit_test_state_dict_transform():\n",
        "\n",
        "    CACHE_DIR = Path(os.getcwd()) / \"huggingface\"\n",
        "    MODEL_CACHE_DIR = CACHE_DIR / \"hub\"\n",
        "    DEVICE = \"cuda:0\"\n",
        "    MODEL_PATH = \"google/gemma-2-9b-it\"\n",
        "\n",
        "    # Step 1: Load HT Transformer\n",
        "    ht_model = HookedTransformer.from_pretrained_no_processing(\n",
        "                    MODEL_PATH,\n",
        "                    device=DEVICE,\n",
        "                    dtype=torch.bfloat16,\n",
        "                    default_padding_side=\"left\",\n",
        "                    cache_dir = MODEL_CACHE_DIR,\n",
        "                    # bf16=True\n",
        "                )\n",
        "\n",
        "    # Step 2: Store Required\n",
        "    ht_initial_storage = {}\n",
        "    ht_initial_storage[\"embed.W_E\"] = ht_model.state_dict()[\"embed.W_E\"].clone().detach().cpu()\n",
        "    ht_initial_storage[\"unembed.W_U\"] = ht_model.state_dict()[\"unembed.W_U\"].clone().detach().cpu()\n",
        "    for i in ht_model.state_dict():\n",
        "        if \"blocks.6\" in i:\n",
        "            ht_initial_storage[i] = ht_model.state_dict()[i]\n",
        "\n",
        "    # Step 3: Load PreTrain HF Gemma2 9B Instruct\n",
        "    hf_model = AutoModel.from_pretrained(MODEL_PATH, cache_dir = MODEL_CACHE_DIR)\n",
        "    hf_state_dict = hf_model.state_dict()\n",
        "\n",
        "    # Step 4: Transform the State Dictionary to gemma_ht_state_dict\n",
        "    gemma_ht_state_dict = hf_state_dict_to_HT_state_dict_gemma(hf_state_dict)\n",
        "    print(gemma_ht_state_dict[\"unembed.W_U\"])\n",
        "    # Test 1: Checking whether the dictionary is transformed correctly\n",
        "    assert len(gemma_ht_state_dict) == 718, f\"There is an issue with Transformation function, there are currently only {len(gemma_ht_state_dict)} items in the dictionary!\"\n",
        "    # Step 5: Load the model into the HT (Since we do pretrain no processing, set everything to false)\n",
        "    ht_model.load_and_process_state_dict(gemma_ht_state_dict,\n",
        "                                         fold_ln = False,\n",
        "                                         center_writing_weights = False,\n",
        "                                         center_unembed = False,\n",
        "                                         refactor_factored_attn_matrices = False,\n",
        "                                         fold_value_biases = False)\n",
        "\n",
        "    # Step 6: Extract the weights to compare\n",
        "    ht_new_storage = {}\n",
        "    ht_new_storage[\"embed.W_E\"] = ht_model.state_dict()[\"embed.W_E\"].clone().detach().cpu()\n",
        "    ht_new_storage[\"unembed.W_U\"] = ht_model.state_dict()[\"unembed.W_U\"].clone().detach().cpu()\n",
        "    for i in ht_model.state_dict():\n",
        "        if \"blocks.6\" in i:\n",
        "            ht_new_storage[i] = ht_model.state_dict()[i]\n",
        "\n",
        "    # Test 2:\n",
        "    to_check = list(ht_new_storage.keys())\n",
        "    for name in to_check:\n",
        "        print(f\"Name: {name}\")\n",
        "        print(ht_initial_storage[name], ht_new_storage[name])\n",
        "        assert torch.allclose(input = ht_new_storage[name], other = ht_initial_storage[name]), \"Values are off!\"\n",
        "        print(f\"Check {name} Complete! Values are close!\")\n",
        "\n",
        "    del ht_model, ht_initial_storage, hf_model, hf_state_dict, gemma_ht_state_dict, ht_new_storage\n",
        "    gc.collect()\n",
        "\n",
        "if run_unit_test:\n",
        "    unit_test_state_dict_transform()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vnp65Vsg5x-5"
      },
      "outputs": [],
      "source": [
        "# Choose one model for the experimentw\n",
        "MODEL_PATH = (\n",
        "    # \"Qwen/Qwen2.5-3B-Instruct\"\n",
        "    # \"Qwen/Qwen2.5-7B-Instruct\"\n",
        "    # \"Qwen/Qwen2.5-14B-Instruct\"\n",
        "    # \"Qwen/Qwen2.5-32B-Instruct\"\n",
        "    # \"meta-llama/Llama-3.2-3B-Instruct\"\n",
        "    \"meta-llama/Llama-3.1-8B-Instruct\"\n",
        "    # \"google/gemma-2-9b-it\"\n",
        "    # \"google/gemma-2-27b-it\"\n",
        "    # \"Unispac/Gemma-2-9B-IT-With-Deeper-Safety-Alignment\"\n",
        ")\n",
        "\n",
        "MODEL_NAME = MODEL_PATH.split(\"/\")[-1]\n",
        "\n",
        "#  We use GPU 1 from {0, ..., 7}\n",
        "DEVICE = \"cuda:0\"\n",
        "\n",
        "BATCH_SIZE = 16\n",
        "\n",
        "mode = \"causal\"\n",
        "beta = 0.99\n",
        "beta_str = str(beta).replace(\".\", \"p\")\n",
        "\n",
        "OUTPUT_PARENT_DIR = Path(\"output\") / f\"{MODEL_NAME}\" / f\"momentum_{mode}\"\n",
        "\n",
        "OUTPUT_DIR = OUTPUT_PARENT_DIR / f\"beta_{beta_str}\"\n",
        "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "VISUALIZATION_PARENT_DIR = Path(\"visualization\") / f\"{MODEL_NAME}\"/ f\"momentum_{mode}\"\n",
        "\n",
        "VISUALIZATION_DIR = VISUALIZATION_PARENT_DIR / f\"beta_{beta_str}\"\n",
        "VISUALIZATION_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES\n",
        "\n",
        "if MODEL_PATH not in OFFICIAL_MODEL_NAMES:\n",
        "    OFFICIAL_MODEL_NAMES.append(MODEL_PATH)\n",
        "\n",
        "CACHE_DIR = Path(os.getcwd()) / \"huggingface\"\n",
        "MODEL_CACHE_DIR = CACHE_DIR / \"hub\"\n",
        "DATASETS_CACHE_DIR = CACHE_DIR / \"datasets\"\n",
        "\n",
        "# model = HookedTransformer.from_pretrained_no_processing(\n",
        "#     MODEL_PATH,\n",
        "#     device=DEVICE,\n",
        "#     dtype=torch.bfloat16,\n",
        "#     default_padding_side=\"left\",\n",
        "#     cache_dir = MODEL_CACHE_DIR,\n",
        "#     # bf16=True\n",
        "# )\n",
        "model = custom_model_loader(MODEL_PATH=MODEL_PATH,\n",
        "                            DEVICE=DEVICE,\n",
        "                            MODEL_CACHE_DIR=MODEL_CACHE_DIR)\n",
        "\n",
        "model.tokenizer.padding_side = \"left\"\n",
        "\n",
        "# store original chat template\n",
        "ORIGINAL_CHAT_TEMPLATE = model.tokenizer.chat_template\n",
        "\n",
        "model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ULUODkIfXxnO"
      },
      "outputs": [],
      "source": [
        "# handle pad token for some model\n",
        "if not model.tokenizer.pad_token:\n",
        "    if \"qwen1\" in MODEL_PATH.lower():\n",
        "        model.tokenizer.pad_token = \"<|endoftext|>\"\n",
        "    elif model.tokenizer.eos_token:\n",
        "        model.tokenizer.pad_token = model.tokenizer.eos_token\n",
        "    else:\n",
        "        raise ValueError(\"No pad token found in the tokenizer.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rF7e-u20EFTe"
      },
      "source": [
        "### Load harmful / harmless datasets\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5i1XcVIgHEE1"
      },
      "outputs": [],
      "source": [
        "def get_harmful_instructions():\n",
        "    url = \"https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv\"\n",
        "    response = requests.get(url)\n",
        "\n",
        "    dataset = pd.read_csv(io.StringIO(response.content.decode(\"utf-8\")))\n",
        "    instructions = dataset[\"goal\"].tolist()\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test\n",
        "\n",
        "\n",
        "def get_harmless_instructions():\n",
        "\n",
        "    hf_path = \"tatsu-lab/alpaca\"\n",
        "    dataset = load_dataset(hf_path, cache_dir = DATASETS_CACHE_DIR)\n",
        "\n",
        "    # filter for instructions that do not have inputs\n",
        "    instructions = []\n",
        "    for i in range(len(dataset[\"train\"])):\n",
        "        if dataset[\"train\"][i][\"input\"].strip() == \"\":\n",
        "            instructions.append(dataset[\"train\"][i][\"instruction\"])\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test\n",
        "\n",
        "\n",
        "def get_harmful_instructions_jp():\n",
        "    huggingface_id = \"kunishou/do-not-answer-ja\"\n",
        "    dataset = load_dataset(huggingface_id, cache_dir = DATASETS_CACHE_DIR)\n",
        "\n",
        "    selected_types = [\n",
        "        # # Information Hazards\n",
        "        # \"Risks from leaking or inferring sensitive information (organization/gov)\",\n",
        "        # (\n",
        "        #     \"Compromise privacy by leaking or inferring private information\"\n",
        "        #     \" (person/individual)\"\n",
        "        # ),\n",
        "        # Malicious Uses\n",
        "        \"Assisting illegal activities\",\n",
        "        \"Nudging or advising users to perform unethical or unsafe actions\",\n",
        "        \"Reducing the cost of disinformation campaigns\",\n",
        "        # Discrimination, Exclusion, Toxicity, Hateful, Offensive\n",
        "        \"Social stereotypes and unfair discrimination\",\n",
        "        \"Toxic language (hate speech)\",\n",
        "        \"Adult Content\",\n",
        "        # # Misinformation Harms\n",
        "        # \"Causing material harm by disseminating misinformation e.g. in medicine or law\",\n",
        "        # \"Disseminating false or misleading information\",\n",
        "        # # Human-Chatbot Interaction Harms\n",
        "        # \"Mental Health or Overreliance Crisis\",\n",
        "        # \"Treat Chatbot as a Human\",\n",
        "    ]\n",
        "\n",
        "    instructions = []\n",
        "    for item in dataset[\"train\"]:\n",
        "        if item[\"types_of_harm\"] not in selected_types:\n",
        "            continue\n",
        "        instructions.append(item[\"question\"])\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test\n",
        "\n",
        "\n",
        "def get_harmless_instructions_jp():\n",
        "    huggingface_id = \"Lazycuber/alpaca-jp\"\n",
        "    dataset = load_dataset(huggingface_id, cache_dir = DATASETS_CACHE_DIR)\n",
        "\n",
        "    # filter for instructions that do not have inputs\n",
        "    instructions = []\n",
        "    for item in dataset[\"train\"]:\n",
        "        if item[\"input\"].strip() != \"\":\n",
        "            continue\n",
        "        inst = item[\"instruction\"]\n",
        "        inst = inst.strip(\"「」'\")\n",
        "        instructions.append(inst)\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rth8yvLZJsXs"
      },
      "outputs": [],
      "source": [
        "LANGUAGE = \"en\"\n",
        "\n",
        "if LANGUAGE == \"en\":\n",
        "    harmful_inst_train, harmful_inst_test = get_harmful_instructions()\n",
        "    harmless_inst_train, harmless_inst_test = get_harmless_instructions()\n",
        "elif LANGUAGE == \"jp\":\n",
        "    harmful_inst_train, harmful_inst_test = get_harmful_instructions_jp()\n",
        "    harmless_inst_train, harmless_inst_test = get_harmless_instructions_jp()\n",
        "\n",
        "print(f\"Train: {len(harmful_inst_train)} harmful, {len(harmless_inst_train)} harmless\")\n",
        "print(f\"Test: {len(harmful_inst_test)} harmful, {len(harmless_inst_test)} harmless\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qv2ALDY_J44G"
      },
      "outputs": [],
      "source": [
        "print(\"Harmful instructions:\")\n",
        "for i in range(4):\n",
        "    print(f\"\\t{harmful_inst_train[i]}\")\n",
        "print(\"Harmless instructions:\")\n",
        "for i in range(4):\n",
        "    print(f\"\\t{harmless_inst_train[i]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KOKYA61k8LWt"
      },
      "source": [
        "### Tokenization utils\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P8UPQSfpWOSK"
      },
      "outputs": [],
      "source": [
        "#  Function Generates the (B, L) inputs to the model w.r.t tokenizer of desired model\n",
        "def instructions_to_chat_tokens(\n",
        "    tokenizer: AutoTokenizer,\n",
        "    instructions: List[str],\n",
        ") -> Int[Tensor, \"batch_size seq_len\"]:\n",
        "    #  Checks for if there is a required chat template\n",
        "    if tokenizer.chat_template:\n",
        "        #  This automatically creates the Batch\n",
        "        convos = [\n",
        "            [{\"role\": \"user\", \"content\": instruction}] for instruction in instructions\n",
        "        ]\n",
        "        return tokenizer.apply_chat_template(\n",
        "            convos,\n",
        "            padding=True, # Padding here ensures all prompts are of the same length\n",
        "            truncation=False,\n",
        "            add_generation_prompt=True,\n",
        "            return_tensors=\"pt\",\n",
        "        )\n",
        "    else:\n",
        "        return tokenizer(\n",
        "            instructions, padding=True, truncation=False, return_tensors=\"pt\"\n",
        "        ).input_ids"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VSL6-5HrXxnQ"
      },
      "outputs": [],
      "source": [
        "harmful_sample_toks = instructions_to_chat_tokens(\n",
        "    tokenizer=model.tokenizer, instructions=harmful_inst_train[:2]\n",
        ")\n",
        "harmless_sample_toks = instructions_to_chat_tokens(\n",
        "    tokenizer=model.tokenizer, instructions=harmless_inst_train[:2]\n",
        ")\n",
        "\n",
        "#  Important to note how each sample looks like\n",
        "for sample in harmful_sample_toks[:2]:\n",
        "    print(model.tokenizer.decode(sample))\n",
        "    print(\"-\" * 50)\n",
        "for sample in harmless_sample_toks[:2]:\n",
        "    print(model.tokenizer.decode(sample))\n",
        "    print(\"-\" * 50)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5WqO2IIwXxnQ"
      },
      "outputs": [],
      "source": [
        "model.tokenizer.batch_decode(harmless_sample_toks)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gtrIK8x78SZh"
      },
      "source": [
        "### Generation utils\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "94jRJDR0DRoY"
      },
      "outputs": [],
      "source": [
        "def _generate_with_hooks(\n",
        "    model: HookedTransformer,\n",
        "    toks: Int[Tensor, \"batch_size seq_len\"], #  Note that toks shape is (B, L)\n",
        "    max_tokens_generated: int = BATCH_SIZE,\n",
        "    fwd_hooks=[],\n",
        ") -> List[str]:\n",
        "\n",
        "    #  Note that the assumption here is all prompts are of the same length, plus we have a max generation.\n",
        "    all_toks = torch.zeros(\n",
        "        (toks.shape[0], toks.shape[1] + max_tokens_generated),\n",
        "        dtype=torch.long,\n",
        "        device=toks.device,\n",
        "    )\n",
        "    all_toks[:, : toks.shape[1]] = toks\n",
        "\n",
        "    #  Note that we generate this in greedily\n",
        "    for i in range(max_tokens_generated):\n",
        "        #  For each forward pass construct the hooks\n",
        "        with model.hooks(fwd_hooks=fwd_hooks):\n",
        "            logits = model(all_toks[:, : -max_tokens_generated + i])\n",
        "            next_tokens = logits[:, -1, :].argmax(\n",
        "                dim=-1\n",
        "            )  # greedy sampling (temperature=0)\n",
        "            all_toks[:, -max_tokens_generated + i] = next_tokens\n",
        "\n",
        "    #  Provides the response (index -> words)\n",
        "    #  The output here is a list of responses. [resp_1, resp_2, ...]\n",
        "    return model.tokenizer.batch_decode(\n",
        "        all_toks[:, toks.shape[1] :], skip_special_tokens=False\n",
        "    )\n",
        "\n",
        "#  Wrapper function, converts list of instructions List[str] -> Outputs with respect to modified model based on hooks List[str]\n",
        "def get_generations(\n",
        "    model: HookedTransformer,\n",
        "    instructions: List[str],\n",
        "    tokenizer: AutoTokenizer,\n",
        "    fwd_hooks=[], #  This set of forward hooks determines how the model generates with\n",
        "    max_tokens_generated: int = 64,\n",
        "    batch_size: int = BATCH_SIZE,\n",
        ") -> List[str]:\n",
        "\n",
        "    generations = []\n",
        "\n",
        "    #  Pass here is batch by batch, where it's a batch of prompts\n",
        "    for i in tqdm(range(0, len(instructions), batch_size)):\n",
        "        toks = instructions_to_chat_tokens(\n",
        "            tokenizer=tokenizer, instructions=instructions[i : i + batch_size]\n",
        "        )\n",
        "\n",
        "        #  Recall toks here is (B, L) for fixed, B, L, though for each batch, I believe L might vary, but at least for one pass of B prompts, they have same L\n",
        "\n",
        "        with torch.no_grad():\n",
        "            generation = _generate_with_hooks(\n",
        "                model,\n",
        "                toks,\n",
        "                max_tokens_generated=max_tokens_generated,\n",
        "                fwd_hooks=fwd_hooks,\n",
        "            )\n",
        "        generations.extend(generation)\n",
        "\n",
        "    #  For all prompts, this generates the list of responses\n",
        "    return generations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3iIZUSYXXxnQ"
      },
      "outputs": [],
      "source": [
        "def run_single_sample(model, input, tokenizer, fwd_hooks=[], max_tokens_generated=64):\n",
        "    #  Baseline, generate response with respect to no hook\n",
        "    #  input: str\n",
        "    baseline_generations = get_generations(\n",
        "        model,\n",
        "        [input],\n",
        "        tokenizer,\n",
        "        fwd_hooks=[],\n",
        "        max_tokens_generated=max_tokens_generated,\n",
        "    )\n",
        "    intervention_generations = get_generations(\n",
        "        model,\n",
        "        [input],\n",
        "        tokenizer,\n",
        "        fwd_hooks=fwd_hooks,\n",
        "        max_tokens_generated=max_tokens_generated,\n",
        "    )\n",
        "\n",
        "    print(f\"INSTRUCTION: {repr(input)}\")\n",
        "    print(Fore.GREEN + f\"BASELINE COMPLETION:\")\n",
        "    print(\n",
        "        textwrap.fill(\n",
        "            baseline_generations[0],\n",
        "            width=100,\n",
        "            initial_indent=\"\\t\",\n",
        "            subsequent_indent=\"\\t\",\n",
        "        )\n",
        "    )\n",
        "    print(Fore.RED + f\"INTERVENTION COMPLETION:\")\n",
        "    print(\n",
        "        textwrap.fill(\n",
        "            intervention_generations[0],\n",
        "            width=100,\n",
        "            initial_indent=\"\\t\",\n",
        "            subsequent_indent=\"\\t\",\n",
        "        )\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W9O8dm0_EQRk"
      },
      "source": [
        "## Finding the \"refusal direction\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eCqtlmgvXxnQ"
      },
      "source": [
        "### Helper functions\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MbY79kSP8oOg"
      },
      "outputs": [],
      "source": [
        "def __run_with_cache(model, data, batch_size):\n",
        "    cache = {}\n",
        "    with torch.no_grad():\n",
        "        for i in range(0, len(data), batch_size):\n",
        "            #  For this particular function, with respect to input (B, L), it generates the intermediate activations for each block (pre, mid, post), so output will be layer 3*layer keys, each entry is (B, L, D)\n",
        "            _, batch_cache = model.run_with_cache(\n",
        "                data[i : i + batch_size],\n",
        "                names_filter=lambda hook_name: \"resid\" in hook_name,\n",
        "                return_cache_object=False,\n",
        "            )\n",
        "            for k, v in batch_cache.items():\n",
        "                if k not in cache:\n",
        "                    cache[k] = v.cpu()\n",
        "                else:\n",
        "                    cache[k] = torch.vstack([cache[k], v.cpu()])\n",
        "    #  Cache Keys are each intermediate location; Each entry is (B, L, D) where B is the total number of prompts\n",
        "    return ActivationCache(cache, model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Srue6HQzXxnR"
      },
      "outputs": [],
      "source": [
        "with torch.no_grad():\n",
        "    K = __run_with_cache(model, harmful_sample_toks, 2)\n",
        "K['resid_mid', 1].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "krkwpruRXxnR"
      },
      "outputs": [],
      "source": [
        "def get_template_suffix_toks(tokenizer):\n",
        "    # Since the padding is on the left side, the suffix of all samples are the same\n",
        "    # when using the same template.\n",
        "    # The activations on these suffix tokens are after the prompt has been processed,\n",
        "    # thus it's interesting to see how the activations differ between contrastive\n",
        "    # samples\n",
        "\n",
        "    # get the common suffix between 2 samples\n",
        "    toks = instructions_to_chat_tokens(tokenizer=tokenizer, instructions=[\"a\", \"b\"])\n",
        "\n",
        "    suffix = toks[0]\n",
        "    #  We traverse backwards\n",
        "    for i in range(len(toks[0]) - 1, -1, -1):\n",
        "        #  technically it collects the toks from the earliest differing token, but since instruction list is fix to only a and b, it will determinstically obtain the following:\n",
        "        #  \"Qwen/Qwen2.5-3B-Instruct\" from ids -> token: ['<|im_end|>', 'Ċ', '<|im_start|>', 'assistant', 'Ċ']\n",
        "        if toks[0][i] != toks[1][i]:\n",
        "            suffix = toks[0][i + 1 :]\n",
        "\n",
        "    return tokenizer.convert_ids_to_tokens(suffix)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ufpRNCUdXxnR"
      },
      "outputs": [],
      "source": [
        "def get_activations(\n",
        "    model: HookedTransformer,\n",
        "    instructions: List[str],\n",
        "    batch_size: int = BATCH_SIZE,\n",
        "    act_names: List[str] = [\"resid_mid\", \"resid_post\"], #  The Positions we're interested in\n",
        "    num_last_tokens: int = 1,\n",
        "):\n",
        "    # tokenize instructions\n",
        "    toks = instructions_to_chat_tokens(\n",
        "        tokenizer=model.tokenizer, instructions=instructions\n",
        "    )\n",
        "\n",
        "    # run model on instructions and cache activations\n",
        "    with torch.no_grad():\n",
        "        cache = __run_with_cache(model, toks, batch_size=BATCH_SIZE)\n",
        "\n",
        "    # get activations for the last n tokens\n",
        "    acts = torch.stack(\n",
        "        [\n",
        "            torch.stack(\n",
        "                [cache[act, layer][:, -num_last_tokens:, :] for act in act_names]\n",
        "            )\n",
        "            for layer in range(model.cfg.n_layers)\n",
        "        ]\n",
        "    )\n",
        "\n",
        "    # For acts: layers x resid_modules [ mid, post] x batch [ prompts] x tokens x dim\n",
        "    return acts, cache"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P0WqdieHXxnR"
      },
      "source": [
        "### Extract the activations\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cby8wUmeXxnR"
      },
      "outputs": [],
      "source": [
        "def get_direction_ablation_output_hook(layer, direction: Tensor):\n",
        "    def hook_fn(output, hook):\n",
        "        #  We want to ablate dest -> src  (Input should be negative of harmful - harmless)\n",
        "        nonlocal direction\n",
        "        # nonlocal direction\n",
        "\n",
        "        #  Obtain Activations (Might be a tuple so we obtain the activation component)\n",
        "        if isinstance(output, tuple):\n",
        "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output[0]\n",
        "        else:\n",
        "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output\n",
        "\n",
        "        #  Normalize the direction (dir is 1D vector)\n",
        "        direction = direction / (direction.norm(p = 2) + 1e-8)\n",
        "        direction = direction.to(activation)\n",
        "        activation -= (activation @ direction).unsqueeze(-1) * direction\n",
        "\n",
        "        if isinstance(output, tuple):\n",
        "            return (activation, *output[1:])\n",
        "        else:\n",
        "            return activation\n",
        "\n",
        "    return hook_fn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7dVqO-TWXxnR"
      },
      "outputs": [],
      "source": [
        "import tqdm\n",
        "N_INST_TRAIN = 512\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "# extraction points per decoder block\n",
        "act_names = [\"resid_mid\", \"resid_post\"]\n",
        "\n",
        "# get the template suffix tokens\n",
        "template_suffix_toks = get_template_suffix_toks(model.tokenizer)\n",
        "if not template_suffix_toks:\n",
        "    template_suffix_toks = [\"<last token>\"]\n",
        "\n",
        "# only get the activations of the template suffix tokens since these tokens are the same\n",
        "# for all samples\n",
        "#  The causal nature implies the output of these tokens already contain the information of the prompts itself.\n",
        "num_last_tokens = len(template_suffix_toks)\n",
        "print(\"template_suffix_toks:\", template_suffix_toks)\n",
        "\n",
        "#  File Path Names\n",
        "chosen_token = -1\n",
        "refusal_dirs_path = (\n",
        "    OUTPUT_DIR\n",
        "    / f\"refusal_dirs_{chosen_token}_{LANGUAGE}_{MODEL_PATH.split('/')[-1]}.npy\"\n",
        ")\n",
        "unnormed_refusal_dirs_path = (\n",
        "    OUTPUT_DIR\n",
        "    / f\"refusal_dirs_unnormed_{chosen_token}_{LANGUAGE}_{MODEL_PATH.split('/')[-1]}.npy\"\n",
        ")\n",
        "\n",
        "#  Harmful should be left alone since we don't steer it and it will be deterministic; Furthermore, this is not unique as the model isn't steered\n",
        "output_harmful_file = OUTPUT_PARENT_DIR / f\"acts_harmful_{LANGUAGE}_{MODEL_PATH.split('/')[-1]}.npy\"\n",
        "\n",
        "#  Harmless should be rerun during learning, and this is unique to each to each choice of beta\n",
        "output_harmless_file = OUTPUT_DIR / f\"acts_harmless_{LANGUAGE}_{MODEL_PATH.split('/')[-1]}.npy\"\n",
        "\n",
        "if refusal_dirs_path.exists() and unnormed_refusal_dirs_path.exists() and output_harmless_file.exists():\n",
        "    print(\"loading refusal_dirs and unnormed refusal_dirs from file\")\n",
        "    unnormed_refusal_dirs = np.load(unnormed_refusal_dirs_path)\n",
        "    refusal_dirs = np.load(refusal_dirs_path)\n",
        "    print(\"loading harmful and sequentially steered harmless files\")\n",
        "    harmful_acts = np.load(output_harmful_file)\n",
        "    harmful_acts = torch.from_numpy(harmful_acts)\n",
        "    harmless_acts = np.load(output_harmless_file)\n",
        "    harmless_acts = torch.from_numpy(harmless_acts)\n",
        "else:\n",
        "    # Momentum_mode\n",
        "    momentum_mode = True\n",
        "    v = None\n",
        "\n",
        "    #  Do not apply steering to harmful activation, so running them outside to save computations\n",
        "    if output_harmful_file.exists():\n",
        "        harmful_acts = np.load(output_harmful_file)\n",
        "        harmful_acts = torch.from_numpy(harmful_acts)\n",
        "    else:\n",
        "        harmful_acts, cache = get_activations(\n",
        "            model,\n",
        "            harmful_inst_train[:N_INST_TRAIN],\n",
        "            batch_size=BATCH_SIZE,\n",
        "            act_names=act_names,\n",
        "            num_last_tokens=num_last_tokens,\n",
        "        )\n",
        "        np.save(output_harmful_file, harmful_acts.cpu().float().numpy())\n",
        "\n",
        "    # Normalize and Mean across harmful\n",
        "    harmful_acts_norm = harmful_acts / harmful_acts.norm(dim=-1, keepdim=True)\n",
        "    harmful_acts_norm_mean = harmful_acts_norm.mean(dim = 2)\n",
        "\n",
        "    # Direction will store [layer][mid == 0, post == 1]\n",
        "    directions = {}\n",
        "\n",
        "    # Initialize refusal_dirs storage:\n",
        "    d_model = harmful_acts.shape[-1]\n",
        "    unnormed_refusal_dirs = torch.zeros(model.cfg.n_layers * len(act_names), d_model)\n",
        "\n",
        "    # Used Module Name, each format should be (layer_ind, (0/1))\n",
        "    used_module_names = []\n",
        "    for l in tqdm.tqdm(range(model.cfg.n_layers * len(act_names))):\n",
        "\n",
        "        layer, pos = l // 2, l % 2\n",
        "\n",
        "        #  Forward Hook Construction; Separated for layer and position:\n",
        "        fwd_hooks = []\n",
        "        for tup in used_module_names:\n",
        "            ly, mp = tup\n",
        "            # Note that each direction is from src (harmless) to (harmful);\n",
        "            # Because we ablate, we take the negative direction\n",
        "            if mp == 0:\n",
        "                fwd_hooks.append((f\"blocks.{ly}.hook_resid_mid\", get_direction_ablation_output_hook(ly, -directions[ly][mp])))\n",
        "            elif mp == 1:\n",
        "                fwd_hooks.append((f\"blocks.{ly}.hook_resid_post\", get_direction_ablation_output_hook(ly, -directions[ly][mp])))\n",
        "            else:\n",
        "                raise NotImplementedError(\"mp not 0 or 1\")\n",
        "        # fwd_hooks = [(f\"blocks.{ly}.hook_resid_post\", get_direction_ablation_output_hook(ly, directions[ly])) for ly in used_module_names]\n",
        "            # These hooks will intervene on the model, that's why we need hook_* args.\n",
        "        # get contranstive activations\n",
        "        #  Apply steering to only Harmless Activation (Src)\n",
        "        with model.hooks(fwd_hooks=fwd_hooks):\n",
        "            harmless_acts, cache = get_activations(\n",
        "                model,\n",
        "                harmless_inst_train[:N_INST_TRAIN],\n",
        "                batch_size=BATCH_SIZE,\n",
        "                act_names=act_names,\n",
        "                num_last_tokens=num_last_tokens,\n",
        "            )\n",
        "\n",
        "\n",
        "        # print(harmful_acts.shape)\n",
        "        # print(harmless_acts.shape)\n",
        "\n",
        "        # For each step, we find the difference of normed means (Norm the activations, then take the mean) then we find the refusal direction\n",
        "        # Take the Mean across Batch\n",
        "        harmless_acts_norm = harmless_acts / harmless_acts.norm(dim=-1, keepdim=True)\n",
        "\n",
        "        # Take Mean across Batch\n",
        "        harmless_acts_norm_mean = harmless_acts_norm.mean(dim = 2)\n",
        "\n",
        "        # Take the difference in normed mean (For now, our ref_dir is from harmless -> harmful; but in our directional ablation we take the opposite direction)\n",
        "        ref_dir_set = harmful_acts_norm_mean - harmless_acts_norm_mean\n",
        "        # Specifically, choose the last token\n",
        "        d_model = ref_dir_set.shape[-1]\n",
        "        ref_dir_set = ref_dir_set[:, :, -1].reshape(-1, d_model)\n",
        "        ref_dir = ref_dir_set[l] # This should be = to layer*2 + pos\n",
        "        # seq_harmful_acts_normed = seq_harmful_acts / seq_harmful_acts.norm(dim=-1, keepdim=True)\n",
        "        # seq_harmless_acts_normed = seq_harmless_acts / seq_harmless_acts.norm(dim=-1, keepdim=True)\n",
        "        if pos == 0:\n",
        "            directions[layer] = {}\n",
        "        if momentum_mode:\n",
        "            if l == 0:\n",
        "                v = ref_dir\n",
        "            else:\n",
        "                v = beta * v + ref_dir\n",
        "            directions[layer][pos] = v\n",
        "            unnormed_refusal_dirs[l, :] = v.detach()\n",
        "        else:\n",
        "            directions[layer][pos] = ref_dir\n",
        "            unnormed_refusal_dirs[l, :] = ref_dir.detach()\n",
        "        # harmful_acts_normed_mean = seq_harmful_acts_normed.mean(dim=2)\n",
        "        # harmless_acts_normed_mean = seq_harmless_acts_normed.mean(dim=2)\n",
        "        used_module_names.append((layer, pos))\n",
        "\n",
        "    #  Compute Normed Refusal Directions\n",
        "    refusal_dirs = unnormed_refusal_dirs / unnormed_refusal_dirs.norm(dim=-1, keepdim=True)\n",
        "    refusal_dirs = refusal_dirs.reshape(model.cfg.n_layers, len(act_names), d_model).cpu().float().numpy()\n",
        "    unnormed_refusal_dirs = unnormed_refusal_dirs.reshape(model.cfg.n_layers, len(act_names), d_model).cpu().float().numpy()\n",
        "\n",
        "    #  Set harmless + harmful acts to cpu and float\n",
        "    harmful_acts = harmful_acts.cpu().float()\n",
        "    harmless_acts = harmless_acts.cpu().float()\n",
        "\n",
        "    # Save harmless, refusal directions normed and unnormed\n",
        "    np.save(output_harmless_file, harmless_acts.numpy())\n",
        "    np.save(unnormed_refusal_dirs_path, unnormed_refusal_dirs)\n",
        "    np.save(refusal_dirs_path, refusal_dirs)\n",
        ""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3ZdiGoKVXxnR"
      },
      "source": [
        "### Analyze the activations\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x84idHg1XxnR"
      },
      "outputs": [],
      "source": [
        "from torch.nn.functional import cosine_similarity, normalize\n",
        "\n",
        "\n",
        "def get_pairwise_cosine_similarity(acts_normed):\n",
        "    # comput cosine similarity of each pair of vector from a set of normalized vectors\n",
        "    # acts_normed is ... x batch x toks x dim\n",
        "\n",
        "    acts_normed = torch.tensor(acts_normed, device=\"cuda\")\n",
        "\n",
        "    # ... batch1 toks dim, ... batch2 toks dim -> ... toks batch1 batch2\n",
        "    acts_pairwise_sim = torch.einsum(\"...ikl,...jkl->...kij\", acts_normed, acts_normed)\n",
        "    #  Compute cosine similarity, for each of the n final tokens, with n constant, the cosine similarity between each batch (prompt)\n",
        "\n",
        "    batch_size = acts_pairwise_sim.shape[-1]\n",
        "\n",
        "    # get the indices of the upper triangular part of the batch x batch similarity matrix\n",
        "    indices = np.arange(batch_size**2).reshape(batch_size, batch_size)\n",
        "    indices = indices[np.triu_indices_from(indices, k=1)]\n",
        "\n",
        "\n",
        "    #  Follow the shapes here\n",
        "    # ... x toks x batch x (batch * batch)\n",
        "    acts_pairwise_sim = acts_pairwise_sim.reshape(*acts_pairwise_sim.shape[:-2], -1)\n",
        "    # ... x toks x batch x (batch * (batch - 1) // 2)\n",
        "    acts_pairwise_sim = acts_pairwise_sim[..., indices]\n",
        "    #  At this point, only keep the non-diagonal upper triangular matrix for each token\n",
        "    # ... x (batch * (batch - 1) // 2) x toks\n",
        "    acts_pairwise_sim = acts_pairwise_sim.swapaxes(-1, -2)\n",
        "\n",
        "    return acts_pairwise_sim\n",
        "\n",
        "\n",
        "def get_cosine_with_mean(acts_normed):\n",
        "    # compute cosine similarity of each vector with the mean vector\n",
        "    # acts_normed is ... x batch x toks x dim\n",
        "\n",
        "    #  The mean is taken with respect for fixed tokens, so for token position n = 1, ..., 5, take the mean across the prompts\n",
        "\n",
        "    acts_normed = torch.tensor(acts_normed, device=\"cuda\")\n",
        "    mean_act = acts_normed.mean(axis=2)\n",
        "    mean_act /= mean_act.norm(dim=-1, keepdim=True)\n",
        "\n",
        "    # ... batch toks dim, ... toks dim -> ... batch toks\n",
        "    cosine_with_mean = torch.einsum(\"...ijk,...jk ->...ij\", acts_normed, mean_act)\n",
        "\n",
        "    return cosine_with_mean\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tqD5E8Vc_w5d"
      },
      "outputs": [],
      "source": [
        "# layers x resid_modules x batch x tokens x dim\n",
        "harmful_acts_normed = harmful_acts / harmful_acts.norm(dim=-1, keepdim=True)\n",
        "harmless_acts_normed = harmless_acts / harmless_acts.norm(dim=-1, keepdim=True)\n",
        "\n",
        "# shape: layers x resid_modules x tokens x dim #  Normalize, then average across batch\n",
        "# normalize then get mean because the activation will be normalized by the RMSNorm layer\n",
        "# normalize helps to preserve the directions as the magnitudes are irrelevant after the\n",
        "# RMSNorm layer\n",
        "harmful_acts_normed_mean = harmful_acts_normed.mean(dim=2)\n",
        "harmless_acts_normed_mean = harmless_acts_normed.mean(dim=2)\n",
        "\n",
        "# layers x resid_modules x tokens\n",
        "similarity_scores = (\n",
        "    cosine_similarity(harmful_acts_normed_mean, harmless_acts_normed_mean, dim=-1)\n",
        "    .cpu()\n",
        "    .float()\n",
        "    .numpy()\n",
        ")\n",
        "\n",
        "\n",
        "hidden_dim = harmful_acts.shape[-1]\n",
        "# shape: layers x resid_modules x tokens x dim\n",
        "# Rescale the activations to the same as in RMSNorm (sqrt(hidden_dim))\n",
        "# This effectively makes the values of each vector be standard normal\n",
        "# So regardless of the hidden dimension, each vector will always be a sample from\n",
        "# standard normal\n",
        "# Hence the variance of activation values will be 1\n",
        "harmful_acts_normed_var = (\n",
        "    (harmful_acts_normed * np.sqrt(hidden_dim)).var(dim=2).cpu().float().numpy()\n",
        ")\n",
        "harmless_acts_normed_var = (\n",
        "    (harmless_acts_normed * np.sqrt(hidden_dim)).var(dim=2).cpu().float().numpy()\n",
        ")\n",
        "\n",
        "harmful_acts_normed = harmful_acts_normed.cpu().float().numpy()\n",
        "harmless_acts_normed = harmless_acts_normed.cpu().float().numpy()\n",
        "\n",
        "#  Compute Cosine Similarity, pairwise and with the mean\n",
        "\n",
        "# layers x resid_modules x batch x tokens\n",
        "# cosine of each vector with the mean vector\n",
        "harmful_acts_cosine_with_mean = get_cosine_with_mean(harmful_acts_normed).cpu().numpy()\n",
        "harmless_acts_cosine_with_mean = (\n",
        "    get_cosine_with_mean(harmless_acts_normed).cpu().numpy()\n",
        ")\n",
        "\n",
        "# layers x resid_modules x (batch * (batch - 1) // 2) x tokens\n",
        "# cosine similarity of each pair of vectors\n",
        "harmful_acts_pairwise_sim = (\n",
        "    get_pairwise_cosine_similarity(harmful_acts_normed).cpu().numpy()\n",
        ")\n",
        "harmless_acts_pairwise_sim = (\n",
        "    get_pairwise_cosine_similarity(harmless_acts_normed).cpu().numpy()\n",
        ")\n",
        "\n",
        "# layers x resid_modules x tokens\n",
        "# variance of cosine similarity of each pair of vectors\n",
        "harmful_acts_pairwise_sim_var = np.var(harmful_acts_pairwise_sim, axis=-2)\n",
        "harmless_acts_pairwise_sim_var = np.var(harmless_acts_pairwise_sim, axis=-2)\n",
        "\n",
        "#  Storing Variance\n",
        "\n",
        "acts_normed_var = dict()\n",
        "\n",
        "# layers x resid_modules x tokens\n",
        "acts_normed_var[\"harmful\"] = dict(\n",
        "    mean=harmful_acts_normed_var.mean(axis=-1),\n",
        "    max=harmful_acts_normed_var.max(axis=-1),\n",
        ")\n",
        "\n",
        "# layers x resid_modules x tokens\n",
        "acts_normed_var[\"harmless\"] = dict(\n",
        "    mean=harmless_acts_normed_var.mean(axis=-1),\n",
        "    max=harmless_acts_normed_var.max(axis=-1),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ly_xGyYxXxnS"
      },
      "outputs": [],
      "source": [
        "# clean up memory\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ms4RlOieXxnS"
      },
      "source": [
        "### Visualize the activations\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qsPstq_vXxnS"
      },
      "source": [
        "#### Cosine Similarity between harmful and harmless activations at each layer and token position\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zuCO2BM6XxnS"
      },
      "outputs": [],
      "source": [
        "#  Similarity scores is (Layer, Pos(2), Tokens), contains the similarity between harmless and harmful mean across batches, for each layer + pos + token\n",
        "\n",
        "num_layers, num_act_modules, num_tokens = similarity_scores.shape\n",
        "data = similarity_scores.reshape(-1, similarity_scores.shape[-1])\n",
        "y_labels = sum([[f\"{layer}-mid\", f\"{layer}-post\"] for layer in range(num_layers)], [])\n",
        "x_labels = [repr(tok) for tok in template_suffix_toks]\n",
        "\n",
        "\n",
        "fig = px.imshow(\n",
        "    data,\n",
        "    y=y_labels,\n",
        "    labels={\"x\": \"token position\", \"y\": \"layer\", \"color\": \"cosine similarity\"},\n",
        "    aspect=\"auto\",\n",
        ")\n",
        "fig.update_layout(\n",
        "    xaxis={\n",
        "        \"tickmode\": \"array\",\n",
        "        \"ticktext\": x_labels,\n",
        "        \"tickvals\": list(range(len(x_labels))),\n",
        "    },\n",
        "    yaxis={\n",
        "        \"tickmode\": \"array\",\n",
        "        \"ticktext\": list(range(len(y_labels))),\n",
        "        \"tickvals\": list(range(0, len(y_labels), len(act_names))),\n",
        "    },\n",
        "    title=(\n",
        "        \"Cosine Similarity between harmful and harmless activations at each layer and\"\n",
        "        \" token position\"\n",
        "    ),\n",
        ")\n",
        "fig.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kazkp0MEXxnS"
      },
      "outputs": [],
      "source": [
        "def variance_plot(**kwargs):\n",
        "    x = kwargs.pop(\"x\")\n",
        "    y = kwargs.pop(\"y\")\n",
        "    y_mean = y.mean(dim=-1)\n",
        "    y_std = y.std(dim=-1)\n",
        "    y_upper = y_mean + y_std\n",
        "    y_lower = y_mean - y_std\n",
        "    y_upper = y_upper.tolist()\n",
        "    y_lower = y_lower.tolist()\n",
        "    # colour = kwargs.pop(\"color\")\n",
        "\n",
        "    trace = go.Scatter(\n",
        "        x=x + x[::-1],\n",
        "        y=y_upper + y_lower[::-1],\n",
        "        mode=\"lines\",\n",
        "        fill=\"toself\",\n",
        "        line=dict(color=kwargs[\"fillcolor\"], width=0),\n",
        "        **kwargs\n",
        "    )\n",
        "\n",
        "    return trace"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nh9hpC7HXxnS"
      },
      "source": [
        "#### Activation norms at each extraction point\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HCsgSNvPXxnS"
      },
      "outputs": [],
      "source": [
        "num_layers, num_act_modules, num_tokens = similarity_scores.shape\n",
        "\n",
        "chosen_token = -1\n",
        "colour_map = {\n",
        "    \"harmless\": plotly.colors.qualitative.Plotly[0],\n",
        "    \"harmful\": plotly.colors.qualitative.Plotly[1],\n",
        "    \"neutral\": plotly.colors.qualitative.Plotly[3],\n",
        "}\n",
        "colour_map_light = {\n",
        "    \"harmless\": plotly.colors.qualitative.Pastel1[1],\n",
        "    \"harmful\": plotly.colors.qualitative.Pastel1[0],\n",
        "    \"neutral\": plotly.colors.qualitative.Pastel1[3],\n",
        "}\n",
        "colour_map_opaque = {\n",
        "    \"harmless\": None,\n",
        "    \"harmful\": \"rgba(251, 180, 174, 0.3)\",\n",
        "    \"harmless\": \"rgba(179, 205, 227, 0.3)\",\n",
        "}\n",
        "\n",
        "# layers x resid_modules x tokens x batch x dim\n",
        "acts = {\"harmful\": harmful_acts, \"harmless\": harmless_acts}\n",
        "\n",
        "categories = [\"harmless\", \"harmful\"]\n",
        "resid_modules = [\"mid\", \"post\"]\n",
        "\n",
        "x_values = sum([[f\"{l}\", f\"{l}-post\"] for l in range(num_layers)], [])\n",
        "x_values = [str(i) for i in range(2 * num_layers)]\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "for category in categories:\n",
        "    normed_acts = acts[category].norm(dim=-1)\n",
        "    mean_normed_acts = normed_acts.mean(dim=-1)\n",
        "\n",
        "    y_values = mean_normed_acts[..., chosen_token].flatten()\n",
        "\n",
        "    # mean\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=y_values,\n",
        "            name=category,\n",
        "            mode=\"lines+markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map[category], size=3),\n",
        "            showlegend=True,\n",
        "        )\n",
        "    )\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=y_values,\n",
        "            mode=\"lines+markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map_light[category], size=3),\n",
        "            showlegend=False,\n",
        "        )\n",
        "    )\n",
        "\n",
        "    # variance\n",
        "    fig.add_trace(\n",
        "        variance_plot(\n",
        "            x=x_values,\n",
        "            y=normed_acts[:, :, chosen_token].reshape(-1, normed_acts.shape[-1]),\n",
        "            yaxis=\"y\",\n",
        "            fillcolor=colour_map_opaque[category],\n",
        "            showlegend=False,\n",
        "        )\n",
        "    )\n",
        "\n",
        "    # dot markers\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values[1::],\n",
        "            y=y_values[1::],\n",
        "            name=f\"{category}\",\n",
        "            mode=\"markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map[category], size=3),\n",
        "            showlegend=False,\n",
        "        )\n",
        "    )\n",
        "    # for module_idx, module_name in enumerate(resid_modules):\n",
        "    #     if module_name == \"post\":\n",
        "    #         colour = colour_map[category]\n",
        "    #     else:\n",
        "    #         colour = colour_map_light[category]\n",
        "    #     fig.add_trace(\n",
        "    #         go.Scatter(\n",
        "    #             x=x_values[module_idx::2],\n",
        "    #             y=y_values[module_idx::2],\n",
        "    #             name=f\"{category}-{module_name}\",\n",
        "    #             mode=\"markers\",\n",
        "    #             yaxis=\"y\",\n",
        "    #             marker=dict(color=colour, size=3),\n",
        "    #             showlegend=True,\n",
        "    #         )\n",
        "    #     )\n",
        "\n",
        "\n",
        "fig.update_layout(\n",
        "    # title=f\"Activation norms at each layer for {MODEL_PATH}\",\n",
        "    plot_bgcolor=\"white\",\n",
        "    grid=dict(rows=1, columns=1),\n",
        "    xaxis=dict(\n",
        "        type=\"category\",\n",
        "        dtick=4,\n",
        "        title=dict(text=\"Extraction Point\", font=dict(size=20)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        tickfont=dict(size=18),\n",
        "    ),\n",
        "    yaxis=dict(\n",
        "        title=dict(text=\"Activation Norm\", font=dict(size=20)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        zeroline=False,\n",
        "        tickfont=dict(size=18),\n",
        "    ),\n",
        "    hovermode=\"x unified\",\n",
        "    height=250,\n",
        "    # width=20 + 12 * len(x_values),\n",
        "    width=600,\n",
        "    margin=dict(l=0, r=0, t=0, b=0),\n",
        "    legend=dict(x=0.05, y=0.95, font=dict(size=18)),\n",
        ")\n",
        "fig.show()\n",
        "\n",
        "fig.write_image(VISUALIZATION_PARENT_DIR / \"acts_norm.pdf\", scale=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_8nhNdGhXxnT"
      },
      "source": [
        "#### Variance of normed activations at each extraction point\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "scrolled": true,
        "id": "bYzYDddYXxnT"
      },
      "outputs": [],
      "source": [
        "num_layers, num_act_modules, num_tokens = similarity_scores.shape\n",
        "\n",
        "chosen_token = -1\n",
        "colour_map = {\n",
        "    \"harmless\": plotly.colors.qualitative.Plotly[0],\n",
        "    \"harmful\": plotly.colors.qualitative.Plotly[1],\n",
        "    \"neutral\": plotly.colors.qualitative.Plotly[3],\n",
        "}\n",
        "colour_map_light = {\n",
        "    \"harmless\": plotly.colors.qualitative.Pastel1[1],\n",
        "    \"harmful\": plotly.colors.qualitative.Pastel1[0],\n",
        "    \"neutral\": plotly.colors.qualitative.Pastel1[3],\n",
        "}\n",
        "\n",
        "categories = [\"harmless\", \"harmful\"]\n",
        "resid_modules = [\"mid\", \"post\"]\n",
        "metrics = [\"mean\", \"max\"]\n",
        "\n",
        "x_values = sum([[f\"{l}\", f\"{l}-post\"] for l in range(num_layers)], [])\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "for m, metric in enumerate(metrics):\n",
        "    for category in categories:\n",
        "        y_values = acts_normed_var[category][metric][..., chosen_token]\n",
        "        fig.add_trace(\n",
        "            go.Scatter(\n",
        "                x=x_values,\n",
        "                y=y_values.flatten(),\n",
        "                mode=\"lines\",\n",
        "                yaxis=f\"y{m + 1}\",\n",
        "                marker=dict(color=colour_map_light[category], size=5),\n",
        "                showlegend=False,\n",
        "            )\n",
        "        )\n",
        "        for module_idx, module_name in enumerate(resid_modules):\n",
        "            if module_name == \"mid\":\n",
        "                colour = colour_map[category]\n",
        "            else:\n",
        "                colour = colour_map_light[category]\n",
        "            fig.add_trace(\n",
        "                go.Scatter(\n",
        "                    x=x_values[module_idx::2],\n",
        "                    y=y_values.flatten()[module_idx::2],\n",
        "                    name=f\"{category}-{module_name}\",\n",
        "                    mode=\"markers\",\n",
        "                    yaxis=f\"y{m + 1}\",\n",
        "                    marker=dict(color=colour, size=5),\n",
        "                    showlegend=m == 0,\n",
        "                )\n",
        "            )\n",
        "\n",
        "diff_mean_var = (\n",
        "    acts_normed_var[\"harmless\"][\"mean\"][..., chosen_token]\n",
        "    - acts_normed_var[\"harmful\"][\"mean\"][..., chosen_token]\n",
        ")\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=x_values,\n",
        "        y=y_values.flatten(),\n",
        "        mode=\"lines\",\n",
        "        yaxis=\"y3\",\n",
        "        marker=dict(color=colour_map_light[\"neutral\"], size=5),\n",
        "        showlegend=False,\n",
        "    )\n",
        ")\n",
        "for module_idx, module_name in enumerate(resid_modules):\n",
        "    if module_name == \"mid\":\n",
        "        colour = colour_map[\"neutral\"]\n",
        "    else:\n",
        "        colour = colour_map_light[\"neutral\"]\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values[module_idx::2],\n",
        "            y=y_values.flatten()[module_idx::2],\n",
        "            name=f\"(harmelss - harmful)-{module_name}\",\n",
        "            mode=\"markers\",\n",
        "            yaxis=\"y3\",\n",
        "            marker=dict(color=colour, size=5),\n",
        "            showlegend=True,\n",
        "        )\n",
        "    )\n",
        "\n",
        "\n",
        "fig.update_layout(\n",
        "    title=f\"Variance of normed activations at each layer for {MODEL_PATH}\",\n",
        "    plot_bgcolor=\"white\",\n",
        "    grid=dict(rows=3, columns=1),\n",
        "    xaxis=dict(\n",
        "        type=\"category\", dtick=2, title=\"Transformers Block\", gridcolor=\"lightgrey\"\n",
        "    ),\n",
        "    yaxis=dict(title=f\"{metrics[0]} variance\", gridcolor=\"lightgrey\", zeroline=False),\n",
        "    yaxis2=dict(title=f\"{metrics[1]} variance\", gridcolor=\"lightgrey\", zeroline=False),\n",
        "    yaxis3=dict(title=f\"harmless - harmful\", gridcolor=\"lightgrey\", zeroline=False),\n",
        "    hovermode=\"x unified\",\n",
        "    height=1200,\n",
        "    # width=20 + 12 * len(x_values),\n",
        ")\n",
        "\n",
        "# fig.update_xaxes(\n",
        "#     mirror=True,\n",
        "#     ticks='outside',\n",
        "#     showline=True,\n",
        "#     # linecolor='black',\n",
        "#     gridcolor='lightgrey'\n",
        "# )\n",
        "# fig.update_yaxes(\n",
        "#     mirror=True,\n",
        "#     ticks='outside',\n",
        "#     showline=True,\n",
        "#     # linecolor='black',\n",
        "#     gridcolor='lightgrey'\n",
        "# )\n",
        "fig.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fkt7CMExXxnT"
      },
      "source": [
        "#### Statistics of activations between harmful and harmless activations at each extraction point\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_xZIOJVbXxnT"
      },
      "outputs": [],
      "source": [
        "similarity_scores.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "scrolled": true,
        "id": "xrcRI1JVXxnT"
      },
      "outputs": [],
      "source": [
        "x_values = sum([[f\"{l}\", f\"{l}-post\"] for l in range(num_layers)], [])\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "# Angles between each pair of mean vectors at each layer\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=x_values,\n",
        "        y=np.arccos(similarity_scores[..., chosen_token]).flatten(),\n",
        "        name=\"harmful-harmless angle\",\n",
        "        # showlegend=False,\n",
        "        mode=\"lines+markers\",\n",
        "        marker=dict(color=colour_map_light[\"neutral\"], size=5),\n",
        "        yaxis=\"y\",\n",
        "    )\n",
        ")\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=x_values[1::2],\n",
        "        y=np.arccos(similarity_scores[..., 1, chosen_token]).flatten(),\n",
        "        # name=\"cosine similarity\",\n",
        "        showlegend=False,\n",
        "        mode=\"markers\",\n",
        "        marker=dict(color=colour_map[\"neutral\"], size=5),\n",
        "        yaxis=\"y\",\n",
        "    )\n",
        ")\n",
        "\n",
        "for category in categories:\n",
        "    if category == \"harmful\":\n",
        "        # layers x resid_modules x (batch * (batch - 1) // 2)\n",
        "        # cosine of each sample activation vector with other sample activation vectors\n",
        "        acts_pairwise_sim = harmful_acts_pairwise_sim[..., chosen_token]\n",
        "\n",
        "        # layers x resid_modules x batch\n",
        "        # cosine of each sample activation vector with the mean activation vector\n",
        "        acts_cosine_with_mean = harmful_acts_cosine_with_mean[..., chosen_token].clip(\n",
        "            -1, 1\n",
        "        )\n",
        "    else:\n",
        "        # layers x resid_modules x (batch * (batch - 1) // 2)\n",
        "        # cosine of each sample activation vector with other sample activation vectors\n",
        "        acts_pairwise_sim = harmless_acts_pairwise_sim[..., chosen_token]\n",
        "\n",
        "        # layers x resid_modules x batch\n",
        "        # cosine of each sample activation vector with the mean activation vector\n",
        "        acts_cosine_with_mean = harmless_acts_cosine_with_mean[..., chosen_token].clip(\n",
        "            -1, 1\n",
        "        )\n",
        "\n",
        "    acts_arccos_with_mean = np.arccos(acts_cosine_with_mean)\n",
        "\n",
        "    count = acts_cosine_with_mean.shape[-1]\n",
        "\n",
        "    for module_idx, module_name in enumerate(resid_modules):\n",
        "        if module_name == \"pre\":\n",
        "            colour = colour_map[category]\n",
        "        else:\n",
        "            colour = colour_map_light[category]\n",
        "\n",
        "        # fig.add_trace(\n",
        "        #     go.Box(\n",
        "        #         x=sum([[name] * count for name in x_values[module_idx::2]], []),\n",
        "        #         y=np.arccos(acts_arccos_with_mean[..., module_idx, :]).flatten(),\n",
        "        #         boxmean=True,\n",
        "        #         line_width=1,\n",
        "        #         marker_size=2,\n",
        "        #         showlegend=False,\n",
        "        #         marker_color=colour,\n",
        "        #         yaxis=\"y2\",\n",
        "        #     ),\n",
        "        # )\n",
        "        fig.add_trace(\n",
        "            variance_plot(\n",
        "                x=x_values,\n",
        "                y=torch.tensor(acts_arccos_with_mean).reshape(\n",
        "                    -1, acts_arccos_with_mean.shape[-1]\n",
        "                ),\n",
        "                yaxis=\"y2\",\n",
        "                fillcolor=colour_map_opaque[category],\n",
        "                showlegend=False,\n",
        "            )\n",
        "        )\n",
        "\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=acts_arccos_with_mean.mean(axis=-1).flatten(),\n",
        "            mode=\"lines+markers\",\n",
        "            showlegend=False,\n",
        "            marker=dict(color=colour_map[category], size=3),\n",
        "            line_width=1,\n",
        "            yaxis=\"y2\",\n",
        "        ),\n",
        "    )\n",
        "    # fig.add_trace(\n",
        "    #     go.Scatter(\n",
        "    #         x=x_values[1::2],\n",
        "    #         y=acts_pairwise_sim.mean(axis=-1).flatten(),\n",
        "    #         mode=\"markers\",\n",
        "    #         showlegend=False,\n",
        "    #         marker_color=colour_map[category],\n",
        "    #         yaxis=\"y4\",\n",
        "    #     ),\n",
        "    # )\n",
        "\n",
        "#  The mean of comparison here: for each position (layer, pos), and each token, we take the mean\n",
        "#  harmful_acts_cosine_with_mean[..., chosen_token]: For each position (layer, (pos)), for the chosen token, this is the cosine similarity between the inidvidual batch (prompt) and the mean across prompt\n",
        "#  Now, take the cos^-1 and take the mean across the prompts. Output should be a value at each position (layer, pos)\n",
        "\n",
        "harmful_locality_scores = np.arccos(\n",
        "    harmful_acts_cosine_with_mean[..., chosen_token].clip(-1, 1)\n",
        ").mean(axis=-1)\n",
        "harmless_locality_scores = np.arccos(\n",
        "    harmless_acts_cosine_with_mean[..., chosen_token].clip(-1, 1)\n",
        ").mean(axis=-1)\n",
        "locality_scores = np.maximum(harmful_locality_scores, harmless_locality_scores)\n",
        "\n",
        "#  Recall that similarity scores is (Layer, Pos(2), Tokens), contains the similarity between (harmless mean across batches) and (harmful mean across batches), for each layer + pos + token\n",
        "#  Take the last token\n",
        "sparsity_scores = np.arccos(similarity_scores[..., chosen_token].clip(-1, 1))\n",
        "\n",
        "# scores = sparsity_scores * 2 / (harmful_locality_scores + harmless_locality_scores)\n",
        "# scores = sparsity_scores - locality_scores\n",
        "#  Each entry for the score evaluates, for fixed position and chosen token, numerator is the similarity (rid of dim -1) between (harmful and harmless mean across batches (rid of dim 2))/ denominator: maximum between how close each harmful and harmless prompts were to their respective mean\n",
        "scores = sparsity_scores / locality_scores\n",
        "# scores[\n",
        "#     sparsity_scores <= np.minimum(harmful_locality_scores, harmless_locality_scores)\n",
        "# ] = np.nan\n",
        "fig.add_trace(\n",
        "    go.Scatter(x=x_values, y=scores.flatten(), mode=\"lines+markers\", yaxis=\"y3\")\n",
        ")\n",
        "\n",
        "#  Based on below, the above metric is not used\n",
        "\n",
        "fig.update_layout(\n",
        "    title=(\n",
        "        \"Statistics of activations between harmful and harmless activations at each\"\n",
        "        f\" layer for {MODEL_PATH}\"\n",
        "    ),\n",
        "    plot_bgcolor=\"white\",\n",
        "    grid=dict(rows=3, columns=1),\n",
        "    xaxis=dict(\n",
        "        type=\"category\", title=\"Transformers Block\", dtick=2, gridcolor=\"lightgrey\"\n",
        "    ),\n",
        "    yaxis=dict(\n",
        "        title=\"Harmful-Harmless angle (1)\", gridcolor=\"lightgrey\", zeroline=False\n",
        "    ),\n",
        "    yaxis2=dict(title=\"Pairwise arccos (2)\", gridcolor=\"lightgrey\", zeroline=False),\n",
        "    yaxis3=dict(title=\"(1) / max((2))\", gridcolor=\"lightgrey\", zeroline=False),\n",
        "    # yaxis4=dict(title=\"Pairwise Cosine Similarity between samples\"),\n",
        "    hovermode=\"x unified\",\n",
        "    height=1200,\n",
        "    # width=20 + 12 * len(x_values),\n",
        ")\n",
        "fig.show()\n",
        "\n",
        "# an adhoc attempt to find the best direction\n",
        "chosen_layer, chosen_act_idx = np.unravel_index(\n",
        "    np.nanargmax(scores, axis=None), scores.shape\n",
        ")\n",
        "print(\n",
        "    f\"Best direction at layer {chosen_layer}, module\"\n",
        "    f\" {act_names[chosen_act_idx]}, position {chosen_token}\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pxPt0AA0XxnT"
      },
      "outputs": [],
      "source": [
        "# another adhoc attempt to find the best direction\n",
        "\n",
        "harmful_locality_scores = np.arccos(harmful_acts_cosine_with_mean.clip(-1, 1)).mean(\n",
        "    axis=-2\n",
        ")\n",
        "harmless_locality_scores = np.arccos(harmless_acts_cosine_with_mean.clip(-1, 1)).mean(\n",
        "    axis=-2\n",
        ")\n",
        "locality_scores = harmful_locality_scores + harmless_locality_scores\n",
        "\n",
        "sparsity_scores = np.arccos(similarity_scores.clip(-1, 1))\n",
        "\n",
        "scores = sparsity_scores / np.maximum(harmful_locality_scores, harmless_locality_scores)\n",
        "print(scores.shape)\n",
        "# scores[\n",
        "#     sparsity_scores <= np.minimum(harmful_locality_scores, harmless_locality_scores)\n",
        "# ] = np.nan\n",
        "scores = scores[..., -2:]\n",
        "_chosen_layer, _chosen_act_idx, _chosen_token = np.unravel_index(\n",
        "    np.nanargmax(scores, axis=None), scores.shape\n",
        ")\n",
        "print(\n",
        "    f\"Lowest cosine similarity at layer {_chosen_layer}, module\"\n",
        "    f\" {act_names[_chosen_act_idx]}, position {_chosen_token}\"\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v_EVi2u6XxnU"
      },
      "outputs": [],
      "source": [
        "normalize(harmful_acts_normed_mean[:, :, chosen_token], dim = -1).shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-DDWTFriXxnU"
      },
      "source": [
        "### Modified: Refusal Directions Already Calculated\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RGWafAL_XxnU"
      },
      "source": [
        "## Refusal direction analysis\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UdLI9DrCXxnU"
      },
      "source": [
        "### Pairwise cosine of refusal directions at each layer\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "scrolled": true,
        "id": "AXl2AqFVXxnU"
      },
      "outputs": [],
      "source": [
        "layer_names = sum([[f\"{i}-mid\", f\"{i}-post\"] for i in range(num_layers)], [])\n",
        "\n",
        "# Combine the layer and sublayer -> (layer*2, D)\n",
        "dirs = refusal_dirs.reshape(-1, refusal_dirs.shape[-1])\n",
        "A = dirs @ dirs.T\n",
        "\n",
        "fig = px.imshow(\n",
        "    # np.rad2deg(np.arccos(np.clip(A, 0.0, 1.0))),\n",
        "    A,\n",
        "    x=layer_names,\n",
        "    y=layer_names,\n",
        "    width=len(layer_names) * 14,\n",
        "    height=len(layer_names) * 14,\n",
        "    title=\"Cosine Similarity Matrix\",\n",
        "    color_continuous_scale=\"Viridis\",\n",
        ")\n",
        "fig.update_layout(\n",
        "    yaxis=dict(dtick=1),\n",
        "    xaxis=dict(dtick=1),\n",
        ")\n",
        "# fig.update_traces(xgap=1, ygap=1)\n",
        "fig.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EGjxw11LXxnU"
      },
      "source": [
        "### Mean cosine of refusal directions at each layer with at other layers\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TMuXmp0hXxnU"
      },
      "outputs": [],
      "source": [
        "from torch.nn.functional import normalize\n",
        "\n",
        "layer_names = sum([[f\"{i}\", f\"{i}-post\"] for i in range(num_layers)], [])\n",
        "layer_names = [str(i) for i in range(2 * num_layers)]\n",
        "\n",
        "#  Switching to Unnormed Refusal Directions from raw\n",
        "raw_dirs = torch.asarray(unnormed_refusal_dirs)\n",
        "\n",
        "raw_dirs = raw_dirs.reshape((-1, raw_dirs.shape[-1]))\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=layer_names,\n",
        "        y=raw_dirs.norm(dim=-1),\n",
        "        mode=\"lines+markers\",\n",
        "        yaxis=\"y\",\n",
        "        marker_color=colour_map_light[\"neutral\"],\n",
        "        marker_size=8,\n",
        "        showlegend=False,\n",
        "    )\n",
        ")\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=layer_names[::],\n",
        "        y=raw_dirs.norm(dim=-1)[::],\n",
        "        mode=\"markers\",\n",
        "        yaxis=\"y\",\n",
        "        marker_color=colour_map[\"neutral\"],\n",
        "        marker_size=8,\n",
        "        showlegend=False,\n",
        "    )\n",
        ")\n",
        "\n",
        "print(layer_names[np.argmax(raw_dirs.norm(dim=-1)[:-1])])\n",
        "\n",
        "\n",
        "fig.update_layout(\n",
        "    # title=(\n",
        "    #     \"Statistics of refusal direction candidates at each layer\"\n",
        "    #     f\" layer for {MODEL_PATH}\"\n",
        "    # ),\n",
        "    plot_bgcolor=\"white\",\n",
        "    grid=dict(rows=1, columns=1),\n",
        "    xaxis=dict(\n",
        "        type=\"category\",\n",
        "        title=dict(text=\"Extraction Point\", font=dict(size=28)),\n",
        "        dtick=4,\n",
        "        gridcolor=\"lightgrey\",\n",
        "        tickfont=dict(size=24),\n",
        "    ),\n",
        "    yaxis=dict(\n",
        "        title=dict(text=\"Norm of<br>Refusal Direction\", font=dict(size=28)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        zeroline=False,\n",
        "        tickfont=dict(size=24),\n",
        "    ),\n",
        "    hovermode=\"x unified\",\n",
        "    height=300,\n",
        "    width=1000,\n",
        "    # width=20 + 12 * len(x_values),\n",
        "    margin=dict(l=20, r=20, t=20, b=20),\n",
        ")\n",
        "fig.show()\n",
        "\n",
        "fig.write_image(VISUALIZATION_DIR / f\"norm_refusal.pdf\", scale=5)\n",
        "\n",
        "\n",
        "flatten_dirs = refusal_dirs.reshape(-1, refusal_dirs.shape[-1])\n",
        "pairwise_cosine = flatten_dirs @ flatten_dirs.T\n",
        "# pairwise_cosine = np.arccos(pairwise_cosine)\n",
        "mean_cosine = np.nanmean(pairwise_cosine, axis=-1) # For each position (out of 2*l), calculate the average cosine sim with everyone else\n",
        "\n",
        "fig = go.Figure()\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=layer_names,\n",
        "        y=mean_cosine,\n",
        "        mode=\"lines+markers\",\n",
        "        yaxis=\"y\",\n",
        "        marker_color=colour_map_light[\"neutral\"],\n",
        "        showlegend=False,\n",
        "        marker_size=8,\n",
        "    )\n",
        ")\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=layer_names[::],\n",
        "        y=mean_cosine[::],\n",
        "        mode=\"markers\",\n",
        "        yaxis=\"y\",\n",
        "        marker_color=colour_map[\"neutral\"],\n",
        "        showlegend=False,\n",
        "        marker_size=8,\n",
        "    )\n",
        ")\n",
        "\n",
        "# fig.add_trace(\n",
        "#     go.Scatter(\n",
        "#         x=layer_names,\n",
        "#         y=raw_dirs.norm(dim=-1) + mean_cosine / mean_cosine.max(),\n",
        "#         mode=\"lines+markers\",\n",
        "#         yaxis=\"y3\",\n",
        "#         marker_color=colour_map_light[\"neutral\"],\n",
        "#         showlegend=False\n",
        "#     )\n",
        "# )\n",
        "\n",
        "fig.update_layout(\n",
        "    # title=(\n",
        "    #     \"Statistics of refusal direction candidates at each extraction point\"\n",
        "    #     f\" for {MODEL_PATH}\"\n",
        "    # ),\n",
        "    plot_bgcolor=\"white\",\n",
        "    grid=dict(rows=1, columns=1),\n",
        "    xaxis=dict(\n",
        "        type=\"category\",\n",
        "        title=dict(text=\"Extraction Point\", font=dict(size=28)),\n",
        "        dtick=4,\n",
        "        gridcolor=\"lightgrey\",\n",
        "        tickfont=dict(size=24),\n",
        "    ),\n",
        "    yaxis=dict(\n",
        "        title=dict(text=f\"Mean<br>Cosine Score\", font=dict(size=28)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        zeroline=False,\n",
        "        tickfont=dict(size=24),\n",
        "    ),\n",
        "    hovermode=\"x unified\",\n",
        "    height=300,\n",
        "    width=1000,\n",
        "    # width=20 + 12 * len(x_values),\n",
        "    margin=dict(l=20, r=20, t=20, b=20),\n",
        ")\n",
        "\n",
        "fig.show()\n",
        "layer_names[np.nanargmax(mean_cosine)]\n",
        "\n",
        "\n",
        "fig.write_image(VISUALIZATION_DIR / f\"mean_cosine.pdf\", scale=5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vgJYbu6yXxnU"
      },
      "outputs": [],
      "source": [
        "# sanity check\n",
        "\n",
        "# layers x resid_modules x tokens x batch x dim\n",
        "category2acts_normed = {\n",
        "    \"harmful\": harmful_acts_normed,\n",
        "    \"harmless\": harmless_acts_normed,\n",
        "}\n",
        "\n",
        "print(refusal_dirs.shape)\n",
        "x = category2acts_normed[\"harmful\"][..., chosen_token, :]\n",
        "print(x.shape)\n",
        "\n",
        "a = einops.einsum(\n",
        "    refusal_dirs, x, \"layer act dim, layer act batch dim -> layer act batch\"\n",
        ")\n",
        "# print(a[0][0] - refusal_dirs[0][0] @ x[0][0].T)\n",
        "assert np.allclose(a[0][0], refusal_dirs[0][0] @ x[0][0].T, atol=1e-6)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dgFVw7SfXxnU"
      },
      "source": [
        "### Scalar projections of activations onto the local refusal direction at each extraction point\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kdsG-CnPXxnU"
      },
      "outputs": [],
      "source": [
        "# layers x resid_modules x tokens x batch x dim\n",
        "category2acts_normed = {\n",
        "    \"harmful\": harmful_acts_normed,\n",
        "    \"harmless\": harmless_acts_normed,\n",
        "}\n",
        "\n",
        "# x_values = sum([[f\"{l}\", f\"{l}-post\"] for l in range(num_layers)], [])\n",
        "x_values = [str(i) for i in range(2 * num_layers)]\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "for category in categories:\n",
        "    acts_normed = category2acts_normed[category][:, :, :, chosen_token]\n",
        "    #  Projection each unit norm activation of each cat, for all posn, prompt onto respective posn-wise refusal direction\n",
        "    #  This is just cosine sim when both are unit vectors\n",
        "    projections = einops.einsum(\n",
        "        refusal_dirs,\n",
        "        acts_normed,\n",
        "        \"layer act dim, layer act batch dim -> layer act batch\",\n",
        "    )\n",
        "    projections = torch.tensor(projections)\n",
        "\n",
        "    #  Take Mean across layer\n",
        "    mean_projection = projections.mean(dim=-1)\n",
        "\n",
        "    y_values = mean_projection.flatten()\n",
        "\n",
        "    # mean\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=y_values,\n",
        "            name=category,\n",
        "            mode=\"lines+markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map[category], size=3),\n",
        "            showlegend=True,\n",
        "        )\n",
        "    )\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=y_values,\n",
        "            name=category,\n",
        "            mode=\"lines+markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map_light[category], size=3),\n",
        "            showlegend=False,\n",
        "        )\n",
        "    )\n",
        "\n",
        "    # variance\n",
        "    fig.add_trace(\n",
        "        variance_plot(\n",
        "            x=x_values,\n",
        "            y=projections.reshape(-1, projections.shape[-1]),\n",
        "            yaxis=\"y\",\n",
        "            fillcolor=colour_map_opaque[category],\n",
        "            showlegend=False,\n",
        "        )\n",
        "    )\n",
        "\n",
        "    # dot markers\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values[1::],\n",
        "            y=y_values[1::],\n",
        "            name=f\"{category}\",\n",
        "            mode=\"markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map[category], size=3),\n",
        "            showlegend=False,\n",
        "        )\n",
        "    )\n",
        "\n",
        "\n",
        "fig.update_layout(\n",
        "    # title=f\"Scalar projections of activations onto the local refusal direction at each\n",
        "    # extraction point for {MODEL_PATH}\",\n",
        "    plot_bgcolor=\"white\",\n",
        "    grid=dict(rows=1, columns=1),\n",
        "    xaxis=dict(\n",
        "        type=\"category\",\n",
        "        dtick=4,\n",
        "        title=dict(text=\"Extraction Point\", font=dict(size=20)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        tickfont=dict(size=18),\n",
        "    ),\n",
        "    yaxis=dict(\n",
        "        title=dict(text=\"Scalar Projections\", font=dict(size=20)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        zeroline=False,\n",
        "        tickfont=dict(size=18),\n",
        "    ),\n",
        "    hovermode=\"x unified\",\n",
        "    height=250,\n",
        "    # width=20 + 12 * len(x_values),\n",
        "    width=600,\n",
        "    margin=dict(l=0, r=0, t=0, b=0),\n",
        "    legend=dict(x=0.05, y=0.95, font=dict(size=18)),\n",
        ")\n",
        "fig.show()\n",
        "\n",
        "fig.write_image(VISUALIZATION_DIR / f\"prj_onto_local_refusal_candidates.pdf\", scale=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gZach7SXXxnV"
      },
      "source": [
        "### Criteria for selecting the refusal direction\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kcAPcA70XxnV"
      },
      "outputs": [],
      "source": [
        "raw_dirs.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2NiJ-MnwXxnV"
      },
      "outputs": [],
      "source": [
        "# Criterion 1: candidate with the largest norm\n",
        "# ([:-1] drops the last entry along the first axis, presumably the final layer)\n",
        "criteria = raw_dirs.norm(dim=-1)[:-1]\n",
        "argmax = np.nanargmax(criteria)\n",
        "# the flat index encodes (layer, module) with 2 modules per layer\n",
        "max_norm_layer, max_norm_act_idx = divmod(argmax, 2)\n",
        "print(\n",
        "    f\"Highest refusal direction norm at layer {max_norm_layer}, module\"\n",
        "    f\" {act_names[max_norm_act_idx]}, position {chosen_token}\"\n",
        ")\n",
        "\n",
        "# Criterion 2: candidate with the highest mean cosine similarity to the others\n",
        "argmax = np.nanargmax(mean_cosine)\n",
        "max_mean_cosine_layer, max_mean_cosine_act_idx = divmod(argmax, 2)\n",
        "print(\n",
        "    f\"Highest cosine similarity at layer {max_mean_cosine_layer}, module\"\n",
        "    f\" {act_names[max_mean_cosine_act_idx]}, position {chosen_token}\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sj0YJzozXxnV"
      },
      "source": [
        "### Selecting the refusal direction\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rbaI_MP7XxnV"
      },
      "outputs": [],
      "source": [
        "# Use the candidate that has, on average, the highest cosine similarity with all\n",
        "# the other candidates, read at the last token position\n",
        "chosen_layer, chosen_act_idx = max_mean_cosine_layer, max_mean_cosine_act_idx\n",
        "chosen_token = -1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oc-2HvBtXxnV"
      },
      "source": [
        "### Projection of activations at each extraction point onto the chosen refusal direction\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lEfdKKDvXxnV"
      },
      "outputs": [],
      "source": [
        "fig = go.Figure()\n",
        "\n",
        "for category in [\"harmful\", \"harmless\"]:\n",
        "    # pre-normalised activations for this category\n",
        "    acts_normed = harmful_acts_normed if category == \"harmful\" else harmless_acts_normed\n",
        "\n",
        "    # layers x resid_modules x batch_size x dim\n",
        "    activations = acts_normed[..., chosen_token, :].copy()\n",
        "\n",
        "    # dim\n",
        "    direction = refusal_dirs[chosen_layer, chosen_act_idx].copy()\n",
        "\n",
        "    # layers x resid_modules x batch_size\n",
        "    scalar_projections = einops.einsum(\n",
        "        activations,\n",
        "        direction,\n",
        "        \"... batch_size dim, ... dim -> ... batch_size\",\n",
        "    )\n",
        "    scalar_projections = np.nan_to_num(scalar_projections)\n",
        "    print(category)\n",
        "    print(scalar_projections.mean())\n",
        "\n",
        "    y_values = scalar_projections\n",
        "    batch_size = scalar_projections.shape[-1]\n",
        "\n",
        "    # one x tick per extraction point, in flattened (layer, module) order\n",
        "    x_values = [str(i) for i in range(2 * num_layers)]\n",
        "    # batch mean at every extraction point, shared by the three mean traces below\n",
        "    mean_values = y_values.mean(axis=-1).flatten()\n",
        "\n",
        "    # variance band\n",
        "    fig.add_trace(\n",
        "        variance_plot(\n",
        "            x=x_values,\n",
        "            y=torch.tensor(y_values).reshape(-1, batch_size),\n",
        "            yaxis=\"y\",\n",
        "            fillcolor=colour_map_opaque[category],\n",
        "            showlegend=False,\n",
        "        )\n",
        "    )\n",
        "\n",
        "    # mean\n",
        "    ## for legend\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=mean_values,\n",
        "            mode=\"lines+markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map[category], size=3),\n",
        "            showlegend=True,\n",
        "            name=category,\n",
        "        )\n",
        "    )\n",
        "    ## for lines\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=mean_values,\n",
        "            mode=\"lines\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map_light[category], size=3),\n",
        "            showlegend=False,\n",
        "            name=category,\n",
        "        )\n",
        "    )\n",
        "    ## for markers\n",
        "    fig.add_trace(\n",
        "        go.Scatter(\n",
        "            x=x_values,\n",
        "            y=mean_values,\n",
        "            mode=\"markers\",\n",
        "            yaxis=\"y\",\n",
        "            marker=dict(color=colour_map[category], size=3),\n",
        "            showlegend=False,\n",
        "            name=category,\n",
        "        )\n",
        "    )\n",
        "\n",
        "    # Diagnostic only (printed, not plotted): subtract twice the positive part of each\n",
        "    # activation's projection onto `direction` (a reflection if `direction` is unit-norm)\n",
        "    # and report the mean projection afterwards; NaNs are zeroed as for the first projection\n",
        "    activations -= 2 * einops.einsum(\n",
        "        np.maximum(scalar_projections, 0),\n",
        "        direction,\n",
        "        \"layer resid_module batch_size, dim -> layer resid_module batch_size dim\",\n",
        "    )\n",
        "    scalar_projections = np.nan_to_num(\n",
        "        einops.einsum(\n",
        "            activations,\n",
        "            direction,\n",
        "            \"... batch_size dim, ... dim -> ... batch_size\",\n",
        "        )\n",
        "    )\n",
        "    print(category)\n",
        "    print(scalar_projections.mean())\n",
        "\n",
        "\n",
        "# module labels, also read by the weight-projection cell further down\n",
        "module_names = [\"mid\", \"post\"]\n",
        "\n",
        "layout_kwargs = dict(\n",
        "    grid=dict(rows=1, columns=1),\n",
        "    plot_bgcolor=\"white\",\n",
        "    xaxis=dict(\n",
        "        type=\"category\",\n",
        "        dtick=4,\n",
        "        title=dict(text=\"Extraction Point\", font=dict(size=20)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        tickfont=dict(size=18),\n",
        "    ),\n",
        "    yaxis=dict(\n",
        "        title=dict(text=\"Scalar Projections\", font=dict(size=20)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        zeroline=False,\n",
        "        tickfont=dict(size=18),\n",
        "    ),\n",
        "    hovermode=\"x unified\",\n",
        "    height=250,\n",
        "    width=600,\n",
        "    margin=dict(l=20, r=20, t=20, b=20),\n",
        "    legend=dict(x=0.05, y=0.95, font=dict(size=18)),\n",
        ")\n",
        "fig.update_layout(**layout_kwargs)\n",
        "fig.show()\n",
        "\n",
        "fig.write_image(VISUALIZATION_DIR / \"prj_onto_refusal_dir.pdf\", scale=5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bn8y0F_3XxnV"
      },
      "outputs": [],
      "source": [
        "# sanity check: normalising one raw activation here must reproduce the\n",
        "# corresponding entry of the pre-normalised numpy array\n",
        "\n",
        "print(harmful_acts.shape)\n",
        "a = harmful_acts[chosen_layer, chosen_act_idx, 0, chosen_token]\n",
        "# cast to float32 *before* normalising, so that a low-precision model dtype\n",
        "# (e.g. bf16/fp16) does not inject rounding error into the comparison\n",
        "an = a.cpu().float()\n",
        "an = an / an.norm()\n",
        "b = harmful_acts_normed[chosen_layer, chosen_act_idx, 0, chosen_token].copy()\n",
        "\n",
        "print(an.dtype)\n",
        "print(b.dtype)\n",
        "\n",
        "print(an, np.linalg.norm(an))\n",
        "print(b, np.linalg.norm(b))\n",
        "# note: rtol=10e-6 equals 1e-5\n",
        "np.testing.assert_allclose(an, b, rtol=10e-6)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jod7uW3vXxnV"
      },
      "source": [
        "### Scalar projections of weights at each SelfAttn and MLP layer onto the refusal direction\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P3UbDEr9XxnV"
      },
      "outputs": [],
      "source": [
        "layer_names = sum([[f\"{i}-mid\", f\"{i}-post\"] for i in range(num_layers)], [])\n",
        "direction = refusal_dirs[chosen_layer, chosen_act_idx].copy()\n",
        "direction /= np.linalg.norm(direction)\n",
        "\n",
        "x_values = []\n",
        "prj_values = []\n",
        "sum_magnitude = []\n",
        "\n",
        "for i in range(num_layers):\n",
        "    # attention output weights are labelled \"{i}-mid\", MLP output weights \"{i}-post\"\n",
        "    for offset, W in enumerate([model.blocks[i].attn.W_O, model.blocks[i].mlp.W_out]):\n",
        "        # contract the weight's last axis (presumably d_model) with the unit direction\n",
        "        prjs = (W.detach().cpu().float().numpy() @ direction).flatten()\n",
        "        prj_values.append(prjs)\n",
        "        x_values.extend([layer_names[i * 2 + offset]] * prjs.shape[0])\n",
        "        sum_magnitude.append(np.sum(np.abs(prjs)))\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "fig.add_trace(\n",
        "    go.Box(\n",
        "        x=x_values,\n",
        "        y=np.hstack(prj_values),\n",
        "        boxmean=True,\n",
        "        marker_color=px.colors.qualitative.Plotly[3],\n",
        "    )\n",
        ")\n",
        "fig.update_layout(\n",
        "    title=(\n",
        "        \"Scalar projections of weights at each layer onto the refusal direction\"\n",
        "        f\" ({chosen_layer}-{module_names[chosen_act_idx]})\"\n",
        "    ),\n",
        "    yaxis_range=[-0.5, 0.5],\n",
        ")\n",
        "\n",
        "fig.show()\n",
        "\n",
        "fig = px.line(\n",
        "    x=layer_names,\n",
        "    y=sum_magnitude,\n",
        "    markers=True,\n",
        "    title=\"Sum of absolute scalar projections of weights onto the refusal direction\",\n",
        "    labels={\"x\": \"Layer-module\", \"y\": \"Sum of |scalar projection|\"},\n",
        ")\n",
        "fig.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2EoxY5i1CWe3"
      },
      "source": [
        "## Constructing the steering plane\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nWzk7_bFXxnW"
      },
      "source": [
        "### Compute the PCA from candidate directions\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d4tkKSq6XxnW"
      },
      "outputs": [],
      "source": [
        "from sklearn.decomposition import PCA\n",
        "\n",
        "# flatten all leading axes so that every candidate direction becomes one row\n",
        "refusal_dirs_flatten = refusal_dirs.reshape(-1, refusal_dirs.shape[-1])\n",
        "pca_model = PCA().fit(refusal_dirs_flatten)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FqZSjKCMXxnW"
      },
      "outputs": [],
      "source": [
        "# explained-variance ratio of every fitted principal component\n",
        "# NOTE(review): `vars` shadows the Python builtin `vars()`; kept as-is in case later cells read it\n",
        "vars = pca_model.explained_variance_ratio_\n",
        "vars"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AsnWlrCdXxnW"
      },
      "outputs": [],
      "source": [
        "components = pca_model.components_\n",
        "\n",
        "print(refusal_dirs_flatten.shape)\n",
        "print(components.shape)\n",
        "# angle (degrees) between every candidate and the last principal component;\n",
        "# printed explicitly because a mid-cell expression is never displayed.\n",
        "# clip keeps rounding error from pushing |cos| past 1 (NaN from arccos),\n",
        "# assuming the candidates are unit-norm\n",
        "print(np.degrees(np.arccos(np.clip(refusal_dirs_flatten @ components[-1], -1.0, 1.0))))\n",
        "\n",
        "# another adhoc attempt to find the best direction is to take the mean of the candidates\n",
        "mean_d = refusal_dirs_flatten.mean(axis=0)\n",
        "mean_d /= np.linalg.norm(mean_d)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "thAEJ9CxXxnW"
      },
      "outputs": [],
      "source": [
        "# angle (degrees) between each principal component and the chosen direction\n",
        "print(chosen_layer, chosen_act_idx)\n",
        "_chosen_dir = refusal_dirs[chosen_layer, chosen_act_idx]\n",
        "np.rad2deg(np.arccos(components @ _chosen_dir))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aYZYGNhGXxnW"
      },
      "source": [
        "### Visualization of the candidate directions on the steering plane\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "scrolled": true,
        "id": "hrCl-za6XxnW"
      },
      "outputs": [],
      "source": [
        "# first basis is the chosen direction (by default the candidate with the highest\n",
        "# mean cosine similarity, see the selection cell)\n",
        "print(chosen_layer, chosen_act_idx)\n",
        "u1 = refusal_dirs[chosen_layer][chosen_act_idx].copy()\n",
        "\n",
        "# second basis is the first principal component\n",
        "u2 = components[0].copy()\n",
        "\n",
        "# Gram-Schmidt: orthonormal basis (b1, b2) of the steering plane\n",
        "b1 = u1 / np.linalg.norm(u1)\n",
        "b2 = u2 - (u2 @ b1) * b1\n",
        "b2 /= np.linalg.norm(b2)\n",
        "P = np.outer(b1, b1) + np.outer(b2, b2)\n",
        "\n",
        "# coordinates of every candidate direction in the (b1, b2) plane\n",
        "prj_matrix = np.column_stack([b1, b2])\n",
        "refusal_dirs_mapped = refusal_dirs_flatten @ prj_matrix\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "# per-candidate length of the in-plane coordinates (one value per row rather than a\n",
        "# single matrix norm); arctan2 is scale-invariant, so the angle needs no normalisation\n",
        "norms = np.linalg.norm(refusal_dirs_mapped, axis=-1)\n",
        "x = refusal_dirs_mapped[:, 0]\n",
        "y = refusal_dirs_mapped[:, 1]\n",
        "angle = np.arctan2(y, x)\n",
        "\n",
        "# label -> in-plane coordinates of the basis vector the arrow points to\n",
        "arrow_targets = {\n",
        "    \"chosen<br>direction\": u1 @ prj_matrix,\n",
        "    \"1st PC\": u2 @ prj_matrix,\n",
        "}\n",
        "for label, point in arrow_targets.items():\n",
        "    # arrow from the origin to the projected basis vector\n",
        "    fig.add_annotation(\n",
        "        ax=0,\n",
        "        ay=0,\n",
        "        x=point[0],\n",
        "        y=point[1],\n",
        "        axref=\"x\",\n",
        "        ayref=\"y\",\n",
        "        showarrow=True,\n",
        "        arrowhead=2,\n",
        "        arrowwidth=2,\n",
        "        xanchor=\"right\",\n",
        "        yanchor=\"top\",\n",
        "        opacity=0.5,\n",
        "    )\n",
        "    # text label next to the arrow tip\n",
        "    fig.add_annotation(\n",
        "        x=point[0],\n",
        "        y=point[1],\n",
        "        text=label,\n",
        "        font=dict(size=22),\n",
        "        showarrow=False,\n",
        "        yshift=30,\n",
        "        xshift=20,\n",
        "    )\n",
        "\n",
        "# one arrow marker per candidate direction, coloured by its flat index\n",
        "num_candidates = refusal_dirs_mapped.shape[0]\n",
        "points = go.Scatter(\n",
        "    x=refusal_dirs_mapped[:, 0],\n",
        "    y=refusal_dirs_mapped[:, 1],\n",
        "    text=list(map(str, range(num_candidates))),\n",
        "    mode=\"markers\",\n",
        "    marker=dict(\n",
        "        symbol=\"arrow\",\n",
        "        # presumably converts the counter-clockwise math angle into plotly's marker-angle convention\n",
        "        angle=90 - np.degrees(angle),\n",
        "        size=20,\n",
        "        color=list(range(num_candidates)),\n",
        "        showscale=True,\n",
        "    ),\n",
        "    name=\"layers\",\n",
        "    showlegend=True,\n",
        ")\n",
        "fig.add_trace(points)\n",
        "\n",
        "# caption placed outside the plotting area (paper coordinates), presumably for the colour bar\n",
        "fig.add_annotation(\n",
        "    text=\"Extraction<br>Point\",\n",
        "    font=dict(size=22),\n",
        "    showarrow=False,\n",
        "    xref=\"paper\",\n",
        "    yref=\"paper\",\n",
        "    x=1.17,\n",
        "    y=-0.15,\n",
        ")\n",
        "\n",
        "fig.update_layout(\n",
        "    autosize=False,\n",
        "    height=600,\n",
        "    # lock the aspect ratio so arrow directions are not visually distorted\n",
        "    yaxis_scaleanchor=\"x\",\n",
        "    yaxis_scaleratio=1,\n",
        "    xaxis_dtick=0.5,\n",
        "    yaxis_dtick=0.5,\n",
        "    font=dict(size=22),\n",
        "    margin=dict(l=0, r=100, t=0, b=75),\n",
        "    legend=dict(visible=False),\n",
        ")\n",
        "\n",
        "fig.show()\n",
        "\n",
        "fig.write_image(VISUALIZATION_DIR / \"steering_plane.pdf\", scale=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fkSIVzbdXxnW"
      },
      "source": [
        "## Steering by rotation\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "67MMG0jtXxnW"
      },
      "source": [
        "### Rotation utils\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "knlLLzOdXxnW"
      },
      "outputs": [],
      "source": [
        "def get_rotation_matrix(degree, basis1, basis2):\n",
        "    assert len(basis1.shape) == 1\n",
        "    assert len(basis2.shape) == 1\n",
        "    assert basis1.shape == basis2.shape\n",
        "\n",
        "    n = basis1.shape[-1]\n",
        "\n",
        "    if degree % 360 == 0:\n",
        "        return np.eye(n)\n",
        "\n",
        "    # ensure bases are orthonormal\n",
        "    u = basis1 / np.linalg.norm(basis1)\n",
        "    v = basis2 - (basis2 @ u) * u\n",
        "    v /= np.linalg.norm(v)\n",
        "\n",
        "    theta = np.deg2rad(degree)\n",
        "    cos_theta = np.cos(theta)\n",
        "    sin_theta = np.sin(theta)\n",
        "    # print(cos_theta, sin_theta)\n",
        "\n",
        "    # rotate counter-clockwise\n",
        "    R_theta = [[cos_theta, -sin_theta], [sin_theta, cos_theta]]\n",
        "\n",
        "    uv = np.column_stack([u, v])\n",
        "    R = np.eye(n) - (np.outer(u, u) + np.outer(v, v)) + uv @ R_theta @ uv.T\n",
        "\n",
        "    return R\n",
        "\n",
        "\n",
        "# sanity check: rotate the chosen direction by 30 degrees in the plane spanned by\n",
        "# the last principal component and the mean candidate direction\n",
        "print(chosen_layer, chosen_act_idx)\n",
        "d = refusal_dirs[chosen_layer][chosen_act_idx].copy()\n",
        "\n",
        "# Gram-Schmidt on (last PC, mean direction)\n",
        "b1 = components[-1] / np.linalg.norm(components[-1])\n",
        "b2 = mean_d - (mean_d @ b1) * b1\n",
        "b2 /= np.linalg.norm(b2)\n",
        "P = np.outer(b1, b1) + np.outer(b2, b2)\n",
        "\n",
        "# orthonormalised bases: expected to print ~90\n",
        "deg = np.rad2deg(np.arccos(b1 @ b2))\n",
        "print(deg)\n",
        "\n",
        "R = get_rotation_matrix(30, b1, b2)\n",
        "\n",
        "# normalised in-plane components after (u) and before (v) the rotation\n",
        "u = P @ R @ d\n",
        "u /= np.linalg.norm(u)\n",
        "v = P @ d\n",
        "v /= np.linalg.norm(v)\n",
        "\n",
        "print(np.rad2deg(np.arccos((R @ d) @ d)))\n",
        "print(np.rad2deg(np.arccos(u @ v)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZaCFjke9XxnW"
      },
      "outputs": [],
      "source": [
        "def rotate_to_target(x, target_degree, basis1, basis2):\n",
        "    assert len(basis1.shape) == 1\n",
        "    assert len(basis2.shape) == 1\n",
        "    assert basis1.shape == basis2.shape\n",
        "\n",
        "    n = basis1.shape[-1]\n",
        "\n",
        "    # ensure bases are orthonormal\n",
        "    u = basis1 / np.linalg.norm(basis1)\n",
        "    v = basis2 - (basis2 @ u) * u\n",
        "    v /= np.linalg.norm(v)\n",
        "\n",
        "    theta = np.deg2rad(target_degree)\n",
        "    cos_theta = np.cos(theta)\n",
        "    sin_theta = np.sin(theta)\n",
        "\n",
        "    P = np.outer(u, u) + np.outer(v, v)\n",
        "\n",
        "    # rotate counter-clockwise\n",
        "    R_theta = [[cos_theta, -sin_theta], [sin_theta, cos_theta]]\n",
        "\n",
        "    uv = np.column_stack([u, v])\n",
        "\n",
        "    rotated_component = uv @ R_theta @ np.array([1, 0])\n",
        "    Px = x @ P\n",
        "    scale = np.linalg.norm(Px, axis=-1, keepdims=True)\n",
        "\n",
        "    result = x - Px + scale * rotated_component\n",
        "\n",
        "    return result\n",
        "\n",
        "\n",
        "# sanity check: compare rotate_to_target with an explicit rotation matrix\n",
        "d = refusal_dirs[chosen_layer][chosen_act_idx].copy()\n",
        "b1 = components[-1].copy()\n",
        "b2 = mean_d.copy()\n",
        "b1 = b1 / np.linalg.norm(b1)\n",
        "b2 = b2 - (b2 @ b1) * b1\n",
        "b2 /= np.linalg.norm(b2)\n",
        "P = np.outer(b1, b1) + np.outer(b2, b2)\n",
        "\n",
        "rd = rotate_to_target(d, 60, b1, b2)\n",
        "\n",
        "# rotate d's in-plane component from its current angle (deg) up to 60 degrees.\n",
        "# arccos is unsigned, so R @ d and rd only agree when that component lies on the\n",
        "# b2 side of b1.\n",
        "u = P @ d\n",
        "u /= np.linalg.norm(u)\n",
        "deg = np.rad2deg(np.arccos(u @ b1))\n",
        "R = get_rotation_matrix(60 - deg, b1, b2)\n",
        "print(R @ d)\n",
        "print(rd)\n",
        "\n",
        "# batched inputs of shape (..., dim) must keep their shape\n",
        "print(d.shape)\n",
        "random_batch = np.random.default_rng(0).random((5, 8, d.shape[0]))\n",
        "assert rotate_to_target(random_batch, 60, b1, b2).shape == random_batch.shape\n",
        "print(chosen_layer, chosen_act_idx)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tRSlNKYvXxnX"
      },
      "source": [
        "### Statistics of candidate directions on the steering plane\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ry8De3ePXxnX"
      },
      "outputs": [],
      "source": [
        "print(chosen_layer, chosen_act_idx)\n",
        "# steering plane spanned by the mean candidate direction (b1) and the first PC (b2)\n",
        "b1 = mean_d.copy()\n",
        "b2 = components[0].copy()\n",
        "\n",
        "b1 = b1 / np.linalg.norm(b1)\n",
        "b2 = b2 - (b2 @ b1) * b1\n",
        "b2 /= np.linalg.norm(b2)\n",
        "\n",
        "P = np.outer(b1, b1) + np.outer(b2, b2)\n",
        "\n",
        "# mean_d is unit-norm and lies in the plane, so both prints should match\n",
        "print(mean_d)\n",
        "print(mean_d @ P)\n",
        "\n",
        "proj = refusal_dirs_flatten @ P\n",
        "proj_norm = np.linalg.norm(proj, axis=-1)\n",
        "\n",
        "fig = px.line(x=layer_names, y=proj_norm, labels={\"x\": \"Extraction point\", \"y\": \"Norm\"})\n",
        "fig.update_layout(\n",
        "    title=\"Norm of projections of candidate directions on the steering plane\",\n",
        ")\n",
        "fig.show()\n",
        "\n",
        "# angle of each in-plane projection measured from b1, i.e. from the *mean* direction\n",
        "proj_normed = proj / proj_norm[:, None]\n",
        "proj_angle = np.rad2deg(np.arccos(np.clip(proj_normed @ b1, -1.0, 1.0)))\n",
        "\n",
        "fig = px.line(\n",
        "    x=layer_names, y=proj_angle, labels={\"x\": \"Extraction point\", \"y\": \"Angle (degrees)\"}\n",
        ")\n",
        "fig.update_layout(\n",
        "    title=(\n",
        "        \"Angles between the projections of candidate directions and the mean\"\n",
        "        \" direction on the steering plane\"\n",
        "    ),\n",
        ")\n",
        "fig.show()\n",
        "print(chosen_layer, chosen_act_idx)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dHsPHecQXxnX"
      },
      "source": [
        "### Creating the steering config\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "scrolled": true,
        "id": "mIk6ixKfXxnX"
      },
      "outputs": [],
      "source": [
        "from pprint import pprint\n",
        "\n",
        "from transformers import AutoConfig, AutoModel\n",
        "from accelerate import init_empty_weights\n",
        "\n",
        "# Only the module names are needed, so instantiate the architecture from its config\n",
        "# inside init_empty_weights instead of loading the checkpoint weights.\n",
        "# NOTE(review): AutoModel presumably returns the base model, whose module names may lack\n",
        "# the \"model.\" prefix used by the steering-config keys below; confirm.\n",
        "hf_config = AutoConfig.from_pretrained(MODEL_PATH, cache_dir=MODEL_CACHE_DIR)\n",
        "with init_empty_weights():\n",
        "    hf_model = AutoModel.from_config(hf_config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "scrolled": true,
        "id": "bKjLclhYXxnX"
      },
      "outputs": [],
      "source": [
        "# List all module names so we know which ones to apply steering to (we're\n",
        "# interested in the normalization layers before each MLP and attention layer)\n",
        "pprint([name for name, _module in hf_model.named_modules()])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5SE55beWXxnX"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "from pprint import pprint\n",
        "\n",
        "target_modules = [\"mid\", \"post\"]\n",
        "# because in the model's architecture, there are no explicit modules for the residual stream (pre/mid/post)\n",
        "# thus let's use the inputs of layernorm modules as equivalent\n",
        "# resid-pre = input of input_layernorm\n",
        "# resid-mid = input of post_attention_layernorm / pre_feedforward_layernorm (gemma)\n",
        "# resid-post = input of the next input_layernorm\n",
        "layernorm_modules = [\"input_layernorm\", \"post_attention_layernorm\"]\n",
        "if \"gemma\" in MODEL_PATH:\n",
        "    # NOTE(review): the comment above names pre_feedforward_layernorm as resid-mid for\n",
        "    # gemma, but this adds post_attention_layernorm (already listed) and\n",
        "    # post_feedforward_layernorm - confirm which norms are intended.\n",
        "    layernorm_modules += [\"post_attention_layernorm\", \"post_feedforward_layernorm\"]\n",
        "# de-duplicate while keeping order; a repeated module would only rewrite the same key\n",
        "layernorm_modules = list(dict.fromkeys(layernorm_modules))\n",
        "\n",
        "mean_d = refusal_dirs_flatten.mean(axis=0)\n",
        "mean_d /= np.linalg.norm(mean_d)\n",
        "\n",
        "print(chosen_layer, chosen_act_idx)\n",
        "\n",
        "# The steered modules do not depend on the direction, so list them once:\n",
        "# non-input layernorms are keyed by their own layer, while input_layernorm is\n",
        "# keyed by the next layer (resid-post); the last layer has no next input_layernorm.\n",
        "num_layers = refusal_dirs.shape[0]\n",
        "steered_module_names = []\n",
        "for layer_idx in range(num_layers):\n",
        "    for module in layernorm_modules:\n",
        "        if module != \"input_layernorm\":\n",
        "            steered_module_names.append(f\"model.layers.{layer_idx}.{module}\")\n",
        "        elif layer_idx < num_layers - 1:\n",
        "            steered_module_names.append(f\"model.layers.{layer_idx + 1}.{module}\")\n",
        "\n",
        "# the second basis of the steering plane is shared by every config\n",
        "second_direction = components[0].copy()\n",
        "\n",
        "# saving various steering configs\n",
        "for first_direction, first_dir_name in [\n",
        "    (\n",
        "        refusal_dirs[max_norm_layer][max_norm_act_idx].copy(),\n",
        "        f\"dir_max_norm_{max_norm_layer}_{target_modules[max_norm_act_idx]}\",\n",
        "    ),\n",
        "    (\n",
        "        refusal_dirs[max_mean_cosine_layer][max_mean_cosine_act_idx].copy(),\n",
        "        f\"dir_max_sim_{max_mean_cosine_layer}_{target_modules[max_mean_cosine_act_idx]}\",\n",
        "    ),\n",
        "    (mean_d.copy(), \"dir_mean\"),\n",
        "]:\n",
        "    steering_config = {\n",
        "        module_name: {\n",
        "            \"mode\": \"rotate_to\",\n",
        "            \"first_direction\": first_direction,\n",
        "            \"second_direction\": second_direction,\n",
        "        }\n",
        "        for module_name in steered_module_names\n",
        "    }\n",
        "\n",
        "    output_name = f\"steering_config-en-{first_dir_name}-pca_0.npy\"\n",
        "    print(output_name)\n",
        "    np.save(OUTPUT_DIR / output_name, steering_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZEcD28fNXxnX"
      },
      "source": [
        "### Testing Angular Steering\n",
        "\n",
        "(this takes a while because inference with TransformerLens is slow)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hAbNr-zoXxnX"
      },
      "source": [
        "#### Try with your own prompt\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YCIWO_mLXxnX"
      },
      "outputs": [],
      "source": [
        "# NOTE: disabled example. Before uncommenting, uncomment and run the helpers in the\n",
        "# next cell (`get_rotate_to_target_func`, `activation_rotation_hook`) and define\n",
        "# `refusal_dir` and `intervention_layers` (e.g. list(range(model.cfg.n_layers))),\n",
        "# which this snippet reads but does not define.\n",
        "\n",
        "# transformation_func = get_rotate_to_target_func(\n",
        "#     target_degree=180, basis1=refusal_dir.copy(), basis2=components[0].copy()\n",
        "# )\n",
        "\n",
        "# fwd_hooks = [\n",
        "#     (\n",
        "#         utils.get_act_name(act_name, layer),\n",
        "#         functools.partial(\n",
        "#             activation_rotation_hook,\n",
        "#             transformation_func=transformation_func,\n",
        "#         ),\n",
        "#     )\n",
        "#     for layer in intervention_layers\n",
        "#     for act_idx, act_name in enumerate([\"resid_pre\", \"resid_mid\"])\n",
        "# ]\n",
        "\n",
        "# answer = get_generations(\n",
        "#     model,\n",
        "#     [\"how to pick a lock\"],\n",
        "#     model.tokenizer,\n",
        "#     fwd_hooks=fwd_hooks,\n",
        "#     max_tokens_generated=256,\n",
        "# )\n",
        "\n",
        "# for a in answer:\n",
        "#     print(a.split(model.tokenizer.eos_token)[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4apLw6o5XxnX"
      },
      "source": [
        "#### Steered generation on the test set at every 10 degrees of steering\n",
        "\n",
        "(very very slow)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "jupyter": {
          "outputs_hidden": true
        },
        "id": "90tHDl4MXxnX"
      },
      "outputs": [],
      "source": [
        "# MAX_GENERATION_TOKENS = 512\n",
        "# NUM_TEST_SAMPLES = 4\n",
        "# test_samples = harmful_inst_test[:NUM_TEST_SAMPLES]\n",
        "\n",
        "\n",
        "# def get_rotate_to_target_func(target_degree, basis1, basis2):\n",
        "#     assert len(basis1.shape) == 1\n",
        "#     assert len(basis2.shape) == 1\n",
        "#     assert basis1.shape == basis2.shape\n",
        "\n",
        "#     n = basis1.shape[-1]\n",
        "\n",
        "#     # ensure bases are orthonormal\n",
        "#     u = basis1 / np.linalg.norm(basis1)\n",
        "#     v = basis2 - (basis2 @ u) * u\n",
        "#     v /= np.linalg.norm(v)\n",
        "\n",
        "#     theta = np.deg2rad(target_degree)\n",
        "#     cos_theta = np.cos(theta)\n",
        "#     sin_theta = np.sin(theta)\n",
        "\n",
        "#     P = np.outer(u, u) + np.outer(v, v)\n",
        "\n",
        "#     # rotate counter-clockwise\n",
        "#     R_theta = [[cos_theta, -sin_theta], [sin_theta, cos_theta]]\n",
        "\n",
        "#     uv = np.column_stack([u, v])\n",
        "\n",
        "#     rotated_component = uv @ R_theta @ np.array([1, 0])\n",
        "\n",
        "#     def __func(x: Tensor):\n",
        "#         Px = x @ torch.tensor(P, device=x.device, dtype=x.dtype)\n",
        "#         scale = Px.norm(dim=-1, keepdim=True)\n",
        "\n",
        "#         result = (\n",
        "#             x\n",
        "#             - Px\n",
        "#             + scale * torch.tensor(rotated_component, device=x.device, dtype=x.dtype)\n",
        "#         )\n",
        "\n",
        "#         return result\n",
        "\n",
        "#     return __func\n",
        "\n",
        "\n",
        "# def activation_rotation_hook(\n",
        "#     activation: Float[Tensor, \"... d_act\"],\n",
        "#     hook: HookPoint,\n",
        "#     transformation_func,\n",
        "# ):\n",
        "#     return transformation_func(activation)\n",
        "\n",
        "\n",
        "# if not \"baseline_generations\" in locals():\n",
        "#     # if True:\n",
        "#     baseline_generations = get_generations(\n",
        "#         model,\n",
        "#         test_samples,\n",
        "#         model.tokenizer,\n",
        "#         fwd_hooks=[],\n",
        "#         max_tokens_generated=MAX_GENERATION_TOKENS,\n",
        "#     )\n",
        "\n",
        "# intervention_generations = {}\n",
        "# refusal_dir = refusal_dirs[chosen_layer][chosen_act_idx]\n",
        "# for degree in range(0, 360, 10):\n",
        "#     intervention_layers = list(range(model.cfg.n_layers))\n",
        "#     print(\"degree\", degree)\n",
        "#     print(intervention_layers)\n",
        "\n",
        "#     if degree in intervention_generations:\n",
        "#         continue\n",
        "\n",
        "#     transformation_func = get_rotate_to_target_func(\n",
        "#         target_degree=degree, basis1=refusal_dir.copy(), basis2=components[0].copy()\n",
        "#     )\n",
        "\n",
        "#     fwd_hooks = [\n",
        "#         (\n",
        "#             utils.get_act_name(act_name, layer),\n",
        "#             functools.partial(\n",
        "#                 activation_rotation_hook,\n",
        "#                 transformation_func=transformation_func,\n",
        "#             ),\n",
        "#         )\n",
        "#         for layer in intervention_layers\n",
        "#         for act_name in [\"resid_pre\", \"resid_mid\"]\n",
        "#     ]\n",
        "\n",
        "#     intervention_generations[degree] = get_generations(\n",
        "#         model,\n",
        "#         test_samples,\n",
        "#         model.tokenizer,\n",
        "#         fwd_hooks=fwd_hooks,\n",
        "#         max_tokens_generated=MAX_GENERATION_TOKENS,\n",
        "#     )\n",
        "\n",
        "# for i in range(NUM_TEST_SAMPLES):\n",
        "#     print(f\"INSTRUCTION {i}: {repr(test_samples[i])}\")\n",
        "#     print(Fore.GREEN + \"BASELINE COMPLETION:\")\n",
        "#     print(\n",
        "#         textwrap.fill(\n",
        "#             baseline_generations[i],\n",
        "#             width=100,\n",
        "#             initial_indent=\"\\t\",\n",
        "#             subsequent_indent=\"\\t\",\n",
        "#         )\n",
        "#     )\n",
        "#     print(Fore.RESET)\n",
        "#     for degree in sorted(intervention_generations.keys()):\n",
        "#         print(Fore.RED + f\"INTERVENTION COMPLETION (degree {degree}):\")\n",
        "#         print(intervention_generations[degree][i].split(model.tokenizer.eos_token)[0])\n",
        "#         # print(\n",
        "#         #     textwrap.fill(\n",
        "#         #         intervention_generations[degree][i],\n",
        "#         #         width=100,\n",
        "#         #         initial_indent=\"\\t\",\n",
        "#         #         subsequent_indent=\"\\t\",\n",
        "#         #     )\n",
        "#         # )\n",
        "#         print(Fore.RESET)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "momentum_steering_v2 (3.10.12)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}