{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c9c1ae92",
      "metadata": {
        "id": "c9c1ae92"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import functools\n",
        "import einops\n",
        "import requests\n",
        "import pandas as pd\n",
        "import io\n",
        "import textwrap\n",
        "import gc\n",
        "import numpy as np\n",
        "import plotly\n",
        "\n",
        "\n",
        "from pathlib import Path\n",
        "from datasets import load_dataset\n",
        "from sklearn.model_selection import train_test_split\n",
        "from tqdm import tqdm\n",
        "from torch import Tensor\n",
        "from typing import List\n",
        "from transformer_lens import HookedTransformer, HookedTransformerConfig ,utils, ActivationCache\n",
        "from transformer_lens.hook_points import HookPoint\n",
        "from transformers import AutoTokenizer\n",
        "from jaxtyping import Float, Int\n",
        "from colorama import Fore\n",
        "import plotly.graph_objects as go\n",
        "import plotly.express as px"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b2cd20ef",
      "metadata": {
        "id": "b2cd20ef"
      },
      "source": [
        "### Initialize a Completely new Pretrained Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "57b1c908",
      "metadata": {
        "id": "57b1c908"
      },
      "outputs": [],
      "source": [
        "# Choose one model for the experimentw\n",
        "MODEL_PATH = (\n",
        "    # \"Qwen/Qwen2.5-3B-Instruct\"\n",
        "    # \"Qwen/Qwen2.5-7B-Instruct\"\n",
        "    # \"Qwen/Qwen2.5-14B-Instruct\"\n",
        "    # \"Qwen/Qwen2.5-32B-Instruct\"\n",
        "    # \"meta-llama/Llama-3.2-3B-Instruct\"\n",
        "    \"meta-llama/Llama-3.1-8B-Instruct\"\n",
        "    # \"google/gemma-2-9b-it\"\n",
        "    # \"google/gemma-2-27b-it\"\n",
        "    # \"Unispac/Gemma-2-9B-IT-With-Deeper-Safety-Alignment\"\n",
        ")\n",
        "\n",
        "MODEL_NAME = MODEL_PATH.split(\"/\")[-1]\n",
        "\n",
        "#  We use GPU 1 from {0, ..., 7}\n",
        "DEVICE = \"cuda:0\"\n",
        "\n",
        "BATCH_SIZE = 16\n",
        "beta = 0.9\n",
        "\n",
        "OUTPUT_PARENT_DIR = Path(\"output\")\n",
        "\n",
        "OUTPUT_DIR = OUTPUT_PARENT_DIR\n",
        "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "VISUALIZATION_PARENT_DIR = Path(\"visualization\")\n",
        "\n",
        "VISUALIZATION_DIR = VISUALIZATION_PARENT_DIR\n",
        "VISUALIZATION_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES\n",
        "\n",
        "if MODEL_PATH not in OFFICIAL_MODEL_NAMES:\n",
        "    OFFICIAL_MODEL_NAMES.append(MODEL_PATH)\n",
        "\n",
        "CACHE_DIR = Path(\"/root/.cache/huggingface\")\n",
        "MODEL_CACHE_DIR = CACHE_DIR / \"hub\"\n",
        "DATASETS_CACHE_DIR = CACHE_DIR / \"datasets\"\n",
        "\n",
        "cfg = HookedTransformerConfig(n_layers = 150,\n",
        "                              d_model = 768,\n",
        "                              d_head = 128,\n",
        "                              n_ctx = 2048,\n",
        "                              d_mlp = 1532,\n",
        "                              d_vocab = 128256,\n",
        "                              act_fn = \"silu\",\n",
        "                              normalization_type = \"RMS\",\n",
        "                              device = DEVICE,\n",
        "                              attn_only=False)\n",
        "\n",
        "model = HookedTransformer(cfg)\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.1-8B-Instruct\")\n",
        "model.tokenizer = tokenizer\n",
        "\n",
        "model.tokenizer.padding_side = \"left\"\n",
        "\n",
        "# store original chat template\n",
        "ORIGINAL_CHAT_TEMPLATE = model.tokenizer.chat_template\n",
        "\n",
        "model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8c48d7a5",
      "metadata": {
        "id": "8c48d7a5"
      },
      "outputs": [],
      "source": [
        "# handle pad token for some model\n",
        "if not model.tokenizer.pad_token:\n",
        "    if \"qwen1\" in MODEL_PATH.lower():\n",
        "        model.tokenizer.pad_token = \"<|endoftext|>\"\n",
        "    elif model.tokenizer.eos_token:\n",
        "        model.tokenizer.pad_token = model.tokenizer.eos_token\n",
        "    else:\n",
        "        raise ValueError(\"No pad token found in the tokenizer.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "020324fe",
      "metadata": {
        "id": "020324fe"
      },
      "outputs": [],
      "source": [
        "def get_harmful_instructions():\n",
        "    url = \"https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv\"\n",
        "    response = requests.get(url)\n",
        "\n",
        "    dataset = pd.read_csv(io.StringIO(response.content.decode(\"utf-8\")))\n",
        "    instructions = dataset[\"goal\"].tolist()\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test\n",
        "\n",
        "\n",
        "def get_harmless_instructions():\n",
        "\n",
        "    hf_path = \"tatsu-lab/alpaca\"\n",
        "    dataset = load_dataset(hf_path, cache_dir = DATASETS_CACHE_DIR)\n",
        "\n",
        "    # filter for instructions that do not have inputs\n",
        "    instructions = []\n",
        "    for i in range(len(dataset[\"train\"])):\n",
        "        if dataset[\"train\"][i][\"input\"].strip() == \"\":\n",
        "            instructions.append(dataset[\"train\"][i][\"instruction\"])\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test\n",
        "\n",
        "\n",
        "def get_harmful_instructions_jp():\n",
        "    huggingface_id = \"kunishou/do-not-answer-ja\"\n",
        "    dataset = load_dataset(huggingface_id, cache_dir = DATASETS_CACHE_DIR)\n",
        "\n",
        "    selected_types = [\n",
        "        # # Information Hazards\n",
        "        # \"Risks from leaking or inferring sensitive information (organization/gov)\",\n",
        "        # (\n",
        "        #     \"Compromise privacy by leaking or inferring private information\"\n",
        "        #     \" (person/individual)\"\n",
        "        # ),\n",
        "        # Malicious Uses\n",
        "        \"Assisting illegal activities\",\n",
        "        \"Nudging or advising users to perform unethical or unsafe actions\",\n",
        "        \"Reducing the cost of disinformation campaigns\",\n",
        "        # Discrimination, Exclusion, Toxicity, Hateful, Offensive\n",
        "        \"Social stereotypes and unfair discrimination\",\n",
        "        \"Toxic language (hate speech)\",\n",
        "        \"Adult Content\",\n",
        "        # # Misinformation Harms\n",
        "        # \"Causing material harm by disseminating misinformation e.g. in medicine or law\",\n",
        "        # \"Disseminating false or misleading information\",\n",
        "        # # Human-Chatbot Interaction Harms\n",
        "        # \"Mental Health or Overreliance Crisis\",\n",
        "        # \"Treat Chatbot as a Human\",\n",
        "    ]\n",
        "\n",
        "    instructions = []\n",
        "    for item in dataset[\"train\"]:\n",
        "        if item[\"types_of_harm\"] not in selected_types:\n",
        "            continue\n",
        "        instructions.append(item[\"question\"])\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test\n",
        "\n",
        "\n",
        "def get_harmless_instructions_jp():\n",
        "    huggingface_id = \"Lazycuber/alpaca-jp\"\n",
        "    dataset = load_dataset(huggingface_id, cache_dir = DATASETS_CACHE_DIR)\n",
        "\n",
        "    # filter for instructions that do not have inputs\n",
        "    instructions = []\n",
        "    for item in dataset[\"train\"]:\n",
        "        if item[\"input\"].strip() != \"\":\n",
        "            continue\n",
        "        inst = item[\"instruction\"]\n",
        "        inst = inst.strip(\"「」'\")\n",
        "        instructions.append(inst)\n",
        "\n",
        "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
        "    return train, test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "fd10ad68",
      "metadata": {
        "id": "fd10ad68"
      },
      "outputs": [],
      "source": [
        "LANGUAGE = \"en\"\n",
        "\n",
        "if LANGUAGE == \"en\":\n",
        "    harmful_inst_train, harmful_inst_test = get_harmful_instructions()\n",
        "    harmless_inst_train, harmless_inst_test = get_harmless_instructions()\n",
        "elif LANGUAGE == \"jp\":\n",
        "    harmful_inst_train, harmful_inst_test = get_harmful_instructions_jp()\n",
        "    harmless_inst_train, harmless_inst_test = get_harmless_instructions_jp()\n",
        "\n",
        "print(f\"Train: {len(harmful_inst_train)} harmful, {len(harmless_inst_train)} harmless\")\n",
        "print(f\"Test: {len(harmful_inst_test)} harmful, {len(harmless_inst_test)} harmless\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "696edc99",
      "metadata": {
        "id": "696edc99"
      },
      "outputs": [],
      "source": [
        "print(\"Harmful instructions:\")\n",
        "for i in range(4):\n",
        "    print(f\"\\t{harmful_inst_train[i]}\")\n",
        "print(\"Harmless instructions:\")\n",
        "for i in range(4):\n",
        "    print(f\"\\t{harmless_inst_train[i]}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1f3b2971",
      "metadata": {
        "id": "1f3b2971"
      },
      "outputs": [],
      "source": [
        "#  Function Generates the (B, L) inputs to the model w.r.t tokenizer of desired model\n",
        "def instructions_to_chat_tokens(\n",
        "    tokenizer: AutoTokenizer,\n",
        "    instructions: List[str],\n",
        ") -> Int[Tensor, \"batch_size seq_len\"]:\n",
        "    #  Checks for if there is a required chat template\n",
        "    if tokenizer.chat_template:\n",
        "        #  This automatically creates the Batch\n",
        "        convos = [\n",
        "            [{\"role\": \"user\", \"content\": instruction}] for instruction in instructions\n",
        "        ]\n",
        "        return tokenizer.apply_chat_template(\n",
        "            convos,\n",
        "            padding=True, # Padding here ensures all prompts are of the same length\n",
        "            truncation=False,\n",
        "            add_generation_prompt=True,\n",
        "            return_tensors=\"pt\",\n",
        "        )\n",
        "    else:\n",
        "        return tokenizer(\n",
        "            instructions, padding=True, truncation=False, return_tensors=\"pt\"\n",
        "        ).input_ids"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1e7d15c5",
      "metadata": {
        "id": "1e7d15c5"
      },
      "outputs": [],
      "source": [
        "harmful_sample_toks = instructions_to_chat_tokens(\n",
        "    tokenizer=model.tokenizer, instructions=harmful_inst_train[:2]\n",
        ")\n",
        "harmless_sample_toks = instructions_to_chat_tokens(\n",
        "    tokenizer=model.tokenizer, instructions=harmless_inst_train[:2]\n",
        ")\n",
        "\n",
        "#  Important to note how each sample looks like\n",
        "for sample in harmful_sample_toks[:2]:\n",
        "    print(model.tokenizer.decode(sample))\n",
        "    print(\"-\" * 50)\n",
        "for sample in harmless_sample_toks[:2]:\n",
        "    print(model.tokenizer.decode(sample))\n",
        "    print(\"-\" * 50)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "254834b1",
      "metadata": {
        "id": "254834b1"
      },
      "outputs": [],
      "source": [
        "def __run_with_cache(model, data, batch_size):\n",
        "    cache = {}\n",
        "    with torch.no_grad():\n",
        "        for i in range(0, len(data), batch_size):\n",
        "            #  For this particular function, with respect to input (B, L), it generates the intermediate activations for each block (pre, mid, post), so output will be layer 3*layer keys, each entry is (B, L, D)\n",
        "            _, batch_cache = model.run_with_cache(\n",
        "                data[i : i + batch_size],\n",
        "                names_filter=lambda hook_name: \"resid\" in hook_name,\n",
        "                return_cache_object=False,\n",
        "            )\n",
        "            for k, v in batch_cache.items():\n",
        "                if k not in cache:\n",
        "                    cache[k] = v.cpu()\n",
        "                else:\n",
        "                    cache[k] = torch.vstack([cache[k], v.cpu()])\n",
        "    #  Cache Keys are each intermediate location; Each entry is (B, L, D) where B is the total number of prompts\n",
        "    return ActivationCache(cache, model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2d76f2dd",
      "metadata": {
        "id": "2d76f2dd"
      },
      "outputs": [],
      "source": [
        "def get_template_suffix_toks(tokenizer):\n",
        "    # Since the padding is on the left side, the suffix of all samples are the same\n",
        "    # when using the same template.\n",
        "    # The activations on these suffix tokens are after the prompt has been processed,\n",
        "    # thus it's interesting to see how the activations differ between contrastive\n",
        "    # samples\n",
        "\n",
        "    # get the common suffix between 2 samples\n",
        "    toks = instructions_to_chat_tokens(tokenizer=tokenizer, instructions=[\"a\", \"b\"])\n",
        "\n",
        "    suffix = toks[0]\n",
        "    #  We traverse backwards\n",
        "    for i in range(len(toks[0]) - 1, -1, -1):\n",
        "        #  technically it collects the toks from the earliest differing token, but since instruction list is fix to only a and b, it will determinstically obtain the following:\n",
        "        #  \"Qwen/Qwen2.5-3B-Instruct\" from ids -> token: ['<|im_end|>', 'Ċ', '<|im_start|>', 'assistant', 'Ċ']\n",
        "        if toks[0][i] != toks[1][i]:\n",
        "            suffix = toks[0][i + 1 :]\n",
        "\n",
        "    return tokenizer.convert_ids_to_tokens(suffix)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4c46e969",
      "metadata": {
        "id": "4c46e969"
      },
      "outputs": [],
      "source": [
        "def get_activations(\n",
        "    model: HookedTransformer,\n",
        "    instructions: List[str],\n",
        "    batch_size: int = BATCH_SIZE,\n",
        "    act_names: List[str] = [\"resid_mid\", \"resid_post\"], #  The Positions we're interested in\n",
        "    num_last_tokens: int = 1,\n",
        "):\n",
        "    # tokenize instructions\n",
        "    toks = instructions_to_chat_tokens(\n",
        "        tokenizer=model.tokenizer, instructions=instructions\n",
        "    )\n",
        "\n",
        "    # run model on instructions and cache activations\n",
        "    with torch.no_grad():\n",
        "        cache = __run_with_cache(model, toks, batch_size=BATCH_SIZE)\n",
        "\n",
        "    # get activations for the last n tokens\n",
        "    acts = torch.stack(\n",
        "        [\n",
        "            torch.stack(\n",
        "                [cache[act, layer][:, -num_last_tokens:, :] for act in act_names]\n",
        "            )\n",
        "            for layer in range(model.cfg.n_layers)\n",
        "        ]\n",
        "    )\n",
        "\n",
        "    # For acts: layers x resid_modules [ mid, post] x batch [ prompts] x tokens x dim\n",
        "    return acts, cache"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a7d1c819",
      "metadata": {
        "id": "a7d1c819"
      },
      "outputs": [],
      "source": [
        "def get_CAA_output_hook(layer, direction: Tensor):\n",
        "    def hook_fn(output, hook):\n",
        "        #  We want to ablate dest -> src  (Input should be negative of harmful - harmless)\n",
        "        nonlocal direction\n",
        "        # nonlocal direction\n",
        "\n",
        "        #  Obtain Activations (Might be a tuple so we obtain the activation component)\n",
        "        if isinstance(output, tuple):\n",
        "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output[0]\n",
        "        else:\n",
        "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output\n",
        "\n",
        "        # ActAdd alpha_param\n",
        "        alpha = 1.0\n",
        "        #  Normalize the direction (dir is 1D vector)\n",
        "        direction = direction / (direction.norm(p = 2) + 1e-8)\n",
        "        direction = direction.to(activation)\n",
        "        activation += alpha * direction\n",
        "\n",
        "        if isinstance(output, tuple):\n",
        "            return (activation, *output[1:])\n",
        "        else:\n",
        "            return activation\n",
        "\n",
        "    return hook_fn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4d6e19b2",
      "metadata": {
        "id": "4d6e19b2"
      },
      "outputs": [],
      "source": [
        "import tqdm\n",
        "N_INST_TRAIN = 512\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "# extraction points per decoder block\n",
        "act_names = [\"resid_mid\", \"resid_post\"]\n",
        "\n",
        "# get the template suffix tokens\n",
        "template_suffix_toks = get_template_suffix_toks(model.tokenizer)\n",
        "if not template_suffix_toks:\n",
        "    template_suffix_toks = [\"<last token>\"]\n",
        "\n",
        "# only get the activations of the template suffix tokens since these tokens are the same\n",
        "# for all samples\n",
        "#  The causal nature implies the output of these tokens already contain the information of the prompts itself.\n",
        "num_last_tokens = len(template_suffix_toks)\n",
        "print(\"template_suffix_toks:\", template_suffix_toks)\n",
        "\n",
        "#  File Path Names\n",
        "chosen_token = -1\n",
        "refusal_dirs_path = (\n",
        "    OUTPUT_DIR\n",
        "    / f\"refusal_dirs_{chosen_token}_{LANGUAGE}.npy\"\n",
        ")\n",
        "unnormed_refusal_dirs_path = (\n",
        "    OUTPUT_DIR\n",
        "    / f\"refusal_dirs_unnormed_{chosen_token}_{LANGUAGE}.npy\"\n",
        ")\n",
        "\n",
        "#  Harmful should be left alone since we don't steer it and it will be deterministic; Furthermore, this is not unique as the model isn't steered\n",
        "output_harmful_file = OUTPUT_PARENT_DIR / f\"acts_harmful_{LANGUAGE}.npy\"\n",
        "\n",
        "#  Harmless should be rerun during learning, and this is unique to each to each choice of beta\n",
        "output_harmless_file = OUTPUT_DIR / f\"acts_harmless_{LANGUAGE}.npy\"\n",
        "\n",
        "if refusal_dirs_path.exists() and unnormed_refusal_dirs_path.exists() and output_harmless_file.exists():\n",
        "    print(\"loading refusal_dirs and unnormed refusal_dirs from file\")\n",
        "    unnormed_refusal_dirs = np.load(unnormed_refusal_dirs_path)\n",
        "    refusal_dirs = np.load(refusal_dirs_path)\n",
        "    print(\"loading harmful and sequentially steered harmless files\")\n",
        "    harmful_acts = np.load(output_harmful_file)\n",
        "    harmful_acts = torch.from_numpy(harmful_acts)\n",
        "    harmless_acts = np.load(output_harmless_file)\n",
        "    harmless_acts = torch.from_numpy(harmless_acts)\n",
        "else:\n",
        "    # Momentum_mode\n",
        "    momentum_mode = False\n",
        "    v = None\n",
        "\n",
        "    #  Do not apply steering to harmful activation, so running them outside to save computations\n",
        "    harmful_acts, cache = get_activations(\n",
        "        model,\n",
        "        harmful_inst_train[:N_INST_TRAIN],\n",
        "        batch_size=BATCH_SIZE,\n",
        "        act_names=act_names,\n",
        "        num_last_tokens=num_last_tokens,\n",
        "    )\n",
        "    np.save(output_harmful_file, harmful_acts.cpu().float().numpy())\n",
        "\n",
        "    # Normalize and Mean across harmful\n",
        "    harmful_acts_norm = harmful_acts / harmful_acts.norm(dim=-1, keepdim=True)\n",
        "    harmful_acts_norm_mean = harmful_acts_norm.mean(dim = 2)\n",
        "\n",
        "    # Direction will store [layer][mid == 0, post == 1]\n",
        "    directions = {}\n",
        "\n",
        "    # Initialize refusal_dirs storage:\n",
        "    d_model = harmful_acts.shape[-1]\n",
        "    unnormed_refusal_dirs = torch.zeros(model.cfg.n_layers * len(act_names), d_model)\n",
        "\n",
        "    # Used Module Name, each format should be (layer_ind, (0/1))\n",
        "    used_module_names = []\n",
        "    for l in tqdm.tqdm(range(model.cfg.n_layers * len(act_names))):\n",
        "\n",
        "        layer, pos = l // 2, l % 2\n",
        "\n",
        "        #  Forward Hook Construction; Separated for layer and position:\n",
        "        fwd_hooks = []\n",
        "        for tup in used_module_names:\n",
        "            ly, mp = tup\n",
        "            # Note that each direction is from src (harmless) to (harmful);\n",
        "            # We add, so we take positive direction\n",
        "            if mp == 0:\n",
        "                fwd_hooks.append((f\"blocks.{ly}.hook_resid_mid\", get_CAA_output_hook(ly, directions[ly][mp])))\n",
        "            elif mp == 1:\n",
        "                fwd_hooks.append((f\"blocks.{ly}.hook_resid_post\", get_CAA_output_hook(ly, directions[ly][mp])))\n",
        "            else:\n",
        "                raise NotImplementedError(\"mp not 0 or 1\")\n",
        "        # fwd_hooks = [(f\"blocks.{ly}.hook_resid_post\", get_direction_ablation_output_hook(ly, directions[ly])) for ly in used_module_names]\n",
        "            # These hooks will intervene on the model, that's why we need hook_* args.\n",
        "        # get contranstive activations\n",
        "        #  Apply steering to only Harmless Activation (Src)\n",
        "        with model.hooks(fwd_hooks=fwd_hooks):\n",
        "            harmless_acts, cache = get_activations(\n",
        "                model,\n",
        "                harmless_inst_train[:N_INST_TRAIN],\n",
        "                batch_size=BATCH_SIZE,\n",
        "                act_names=act_names,\n",
        "                num_last_tokens=num_last_tokens,\n",
        "            )\n",
        "\n",
        "\n",
        "        # print(harmful_acts.shape)\n",
        "        # print(harmless_acts.shape)\n",
        "\n",
        "        # For each step, we find the difference of normed means (Norm the activations, then take the mean) then we find the refusal direction\n",
        "        # Take the Mean across Batch\n",
        "        harmless_acts_norm = harmless_acts / harmless_acts.norm(dim=-1, keepdim=True)\n",
        "\n",
        "        # Take Mean across Batch\n",
        "        harmless_acts_norm_mean = harmless_acts_norm.mean(dim = 2)\n",
        "\n",
        "        # Take the difference in normed mean (For now, our ref_dir is from harmless -> harmful; but in our directional ablation we take the opposite direction)\n",
        "        ref_dir_set = harmful_acts_norm_mean - harmless_acts_norm_mean\n",
        "        # Specifically, choose the last token\n",
        "        d_model = ref_dir_set.shape[-1]\n",
        "        ref_dir_set = ref_dir_set[:, :, -1].reshape(-1, d_model)\n",
        "        ref_dir = ref_dir_set[l] # This should be = to layer*2 + pos\n",
        "        # seq_harmful_acts_normed = seq_harmful_acts / seq_harmful_acts.norm(dim=-1, keepdim=True)\n",
        "        # seq_harmless_acts_normed = seq_harmless_acts / seq_harmless_acts.norm(dim=-1, keepdim=True)\n",
        "        if pos == 0:\n",
        "            directions[layer] = {}\n",
        "        if momentum_mode:\n",
        "            if l == 0:\n",
        "                v = ref_dir\n",
        "            else:\n",
        "                v = beta * v + ref_dir\n",
        "            directions[layer][pos] = v\n",
        "            unnormed_refusal_dirs[l, :] = v.detach()\n",
        "        else:\n",
        "            directions[layer][pos] = ref_dir\n",
        "            unnormed_refusal_dirs[l, :] = ref_dir.detach()\n",
        "        # harmful_acts_normed_mean = seq_harmful_acts_normed.mean(dim=2)\n",
        "        # harmless_acts_normed_mean = seq_harmless_acts_normed.mean(dim=2)\n",
        "        used_module_names.append((layer, pos))\n",
        "\n",
        "    #  Compute Normed Refusal Directions\n",
        "    refusal_dirs = unnormed_refusal_dirs / unnormed_refusal_dirs.norm(dim=-1, keepdim=True)\n",
        "    refusal_dirs = refusal_dirs.reshape(model.cfg.n_layers, len(act_names), d_model).cpu().float().numpy()\n",
        "    unnormed_refusal_dirs = unnormed_refusal_dirs.reshape(model.cfg.n_layers, len(act_names), d_model).cpu().float().numpy()\n",
        "\n",
        "    #  Set harmless + harmful acts to cpu and float\n",
        "    harmful_acts = harmful_acts.cpu().float()\n",
        "    harmless_acts = harmless_acts.cpu().float()\n",
        "\n",
        "    # Save harmless, refusal directions normed and unnormed\n",
        "    np.save(output_harmless_file, harmless_acts.numpy())\n",
        "    np.save(unnormed_refusal_dirs_path, unnormed_refusal_dirs)\n",
        "    np.save(refusal_dirs_path, refusal_dirs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "02ce7145",
      "metadata": {
        "id": "02ce7145"
      },
      "outputs": [],
      "source": [
        "colour_map = {\n",
        "    \"harmless\": plotly.colors.qualitative.Plotly[0],\n",
        "    \"harmful\": plotly.colors.qualitative.Plotly[1],\n",
        "    \"neutral\": plotly.colors.qualitative.Plotly[3],\n",
        "}\n",
        "colour_map_light = {\n",
        "    \"harmless\": plotly.colors.qualitative.Pastel1[1],\n",
        "    \"harmful\": plotly.colors.qualitative.Pastel1[0],\n",
        "    \"neutral\": plotly.colors.qualitative.Pastel1[3],\n",
        "}\n",
        "colour_map_opaque = {\n",
        "    \"harmless\": None,\n",
        "    \"harmful\": \"rgba(251, 180, 174, 0.3)\",\n",
        "    \"harmless\": \"rgba(179, 205, 227, 0.3)\",\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "734ce001",
      "metadata": {
        "id": "734ce001"
      },
      "outputs": [],
      "source": [
        "from torch.nn.functional import normalize\n",
        "\n",
        "\n",
        "\n",
        "# raw_dirs = harmful_acts_normed_mean_normed - harmless_acts_normed_mean_normed\n",
        "raw_dirs = torch.asarray(unnormed_refusal_dirs)\n",
        "\n",
        "num_layers = raw_dirs.shape[0]\n",
        "\n",
        "raw_dirs = raw_dirs.reshape((-1, raw_dirs.shape[-1]))\n",
        "\n",
        "layer_names = [str(i) for i in range(2 * num_layers)]\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=layer_names,\n",
        "        y=raw_dirs.norm(dim=-1),\n",
        "        mode=\"lines+markers\",\n",
        "        yaxis=\"y\",\n",
        "        marker_color=colour_map_light[\"neutral\"],\n",
        "        marker_size=8,\n",
        "        showlegend=False,\n",
        "    )\n",
        ")\n",
        "fig.add_trace(\n",
        "    go.Scatter(\n",
        "        x=layer_names,\n",
        "        y=raw_dirs.norm(dim=-1)[0::],\n",
        "        mode=\"markers\",\n",
        "        yaxis=\"y\",\n",
        "        marker_color=colour_map[\"harmful\"],\n",
        "        marker_size=8,\n",
        "        showlegend=False,\n",
        "    )\n",
        ")\n",
        "\n",
        "print(layer_names[np.argmax(raw_dirs.norm(dim=-1)[:-1])])\n",
        "\n",
        "\n",
        "fig.update_layout(\n",
        "    # title=(\n",
        "    #     \"Statistics of refusal direction candidates at each layer\"\n",
        "    #     f\" layer for {MODEL_PATH}\"\n",
        "    # ),\n",
        "    plot_bgcolor=\"white\",\n",
        "    grid=dict(rows=1, columns=1),\n",
        "    xaxis=dict(\n",
        "        type=\"category\",\n",
        "        title=dict(text=\"Extraction Point\", font=dict(size=28)),\n",
        "        dtick=4,\n",
        "        gridcolor=\"lightgrey\",\n",
        "        tickfont=dict(size=24),\n",
        "    ),\n",
        "    yaxis=dict(\n",
        "        title=dict(text=\"Norm of<br>Refusal Direction\", font=dict(size=28)),\n",
        "        gridcolor=\"lightgrey\",\n",
        "        zeroline=False,\n",
        "        tickfont=dict(size=24),\n",
        "    ),\n",
        "    hovermode=\"x unified\",\n",
        "    height=300,\n",
        "    width=1000,\n",
        "    # width=20 + 12 * len(x_values),\n",
        "    margin=dict(l=20, r=20, t=20, b=20),\n",
        ")\n",
        "fig.show()\n",
        "\n",
        "fig.write_image(VISUALIZATION_DIR / f\"norm_refusal.pdf\", scale=5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7ed277a0",
      "metadata": {
        "id": "7ed277a0"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "angular_steering",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.18"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}