{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "82acAhWYGIPx"
   },
   "source": [
    "# Angular Steering\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "j7hOtw7UHXdD"
   },
   "source": [
    "This notebook contains:\n",
    "\n",
    "- The process of extracting the refusal direction and constructing the steering plane.\n",
    "- Visualization of the activation, extracted directions and constructed steering planes.\n",
    "- The creation of the steering config that can be used with our fork of vLLM.\n",
    "\n",
    "This notebook is inspired by https://colab.research.google.com/drive/1a-aQvKC9avdZpdyBn4jgRQFObTPy1JZw\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fcxHyDZw6b86"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dependencies\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dLeei4-T6Wef"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama nbformat plotly datasets pandas scikit-learn matplotlib requests tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_vhhwl-2-jPg"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import functools\n",
    "import einops\n",
    "import requests\n",
    "import pandas as pd\n",
    "import io\n",
    "import textwrap\n",
    "import gc\n",
    "import numpy as np\n",
    "import plotly\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "from pathlib import Path\n",
    "from datasets import load_dataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm import tqdm\n",
    "from torch import Tensor\n",
    "from typing import List\n",
    "from transformer_lens import HookedTransformer, utils, ActivationCache\n",
    "from transformer_lens.hook_points import HookPoint\n",
    "from transformers import AutoTokenizer\n",
    "from jaxtyping import Float, Int\n",
    "from colorama import Fore\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Restrict this process to the listed physical GPU(s). This only takes effect if\n",
    "# CUDA has not been initialized yet in this kernel (importing torch does not by\n",
    "# itself initialize CUDA); restart the kernel after changing it.\n",
    "VISIBLE_GPUS = \"5\"  # comma-separated physical GPU ids, e.g. \"0,2\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = VISIBLE_GPUS\n",
    "\n",
    "# Expected to equal the number of ids in VISIBLE_GPUS (1 for \"5\").\n",
    "print(torch.cuda.device_count())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6ZOoJagxD49V"
   },
   "source": [
    "### Models and configs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 191,
     "referenced_widgets": [
      "ad063e2c68a44f009bfab68c141c09be",
      "89ee88168c474e9fbcf4a17f1483eff4",
      "3877270cf4bc42a9b6142cce7a5d8c54",
      "9a5611a341ed4673aaaf2f463f685d7c",
      "a2de63dfbd6c485e841c6fcd1fefe451",
      "c362d50107dd4a2db0d1a79da2af8d57",
      "ffa85c694b694425999187b346c7ecfe",
      "ec8f6f360a2243b0ac98d34e825ba378",
      "f2ee188bfaa84e9680dbc296b1adbef6",
      "e973493cd6d14381bb4ad2f82417e8a9",
      "89797f6e82104058af92e3ceb094af66"
     ]
    },
    "id": "Vnp65Vsg5x-5",
    "outputId": "25fb5805-fe31-44b0-8f73-6fabc230d261"
   },
   "outputs": [],
   "source": [
    "# Choose one model for the experiment\n",
    "MODEL_PATH = (\n",
    "    \"Qwen/Qwen2.5-3B-Instruct\"\n",
    "    # \"Qwen/Qwen2.5-7B-Instruct\"\n",
    "    # \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "    # \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "    # \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "    # \"meta-llama/Llama-3.1-8B-Instruct\"\n",
    "    # \"google/gemma-2-9b-it\"\n",
    "    # \"openai/gpt-oss-20b\"\n",
    "    # \"google/gemma-2b\"\n",
    ")\n",
    "\n",
    "# Weights of the proportional (current difference-of-means direction), integral\n",
    "# (running sum over earlier modules) and derivative (change from the previous\n",
    "# module) terms that are combined into each steering vector during extraction.\n",
    "p_coe = 1.0\n",
    "i_coe = 2.5\n",
    "d_coe = 0.0\n",
    "\n",
    "# Alternative method tags (uncomment one pair).\n",
    "# METHOD_PREFIX = \"PID_\"\n",
    "# METHOD_SFX = \"_PID\"\n",
    "\n",
    "# METHOD_PREFIX = \"RePE_\"\n",
    "# METHOD_SFX = \"_RePE\"\n",
    "\n",
    "# METHOD_PREFIX = \"ITI_\"\n",
    "# METHOD_SFX = \"_ITI\"\n",
    "\n",
    "METHOD_PREFIX = \"\"\n",
    "METHOD_SFX = \"\"\n",
    "\n",
    "MODEL_NAME = MODEL_PATH.split(\"/\")[-1]\n",
    "\n",
    "DEVICE = \"cuda:0\"\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "# Per-model output folders. NOTE: OUTPUT_DIR is reassigned to a per-run\n",
    "# sub-directory in a later cell.\n",
    "OUTPUT_DIR = Path(\"output\") / f\"{MODEL_NAME}\"\n",
    "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "VISUALIZATION_DIR = Path(\"visualization\") / f\"{MODEL_NAME}\"\n",
    "VISUALIZATION_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES\n",
    "\n",
    "# TransformerLens only accepts model names listed in OFFICIAL_MODEL_NAMES; register\n",
    "# MODEL_PATH so it is accepted (its architecture must still be supported).\n",
    "if MODEL_PATH not in OFFICIAL_MODEL_NAMES:\n",
    "    OFFICIAL_MODEL_NAMES.append(MODEL_PATH)\n",
    "\n",
    "# Load in bf16 without TransformerLens' weight processing (e.g. LayerNorm folding).\n",
    "model = HookedTransformer.from_pretrained_no_processing(\n",
    "    MODEL_PATH,\n",
    "    device=DEVICE,\n",
    "    dtype=torch.bfloat16,\n",
    "    default_padding_side=\"left\",\n",
    ")\n",
    "\n",
    "# Left padding keeps the chat-template suffix aligned at the end of every prompt.\n",
    "model.tokenizer.padding_side = \"left\"\n",
    "\n",
    "# store original chat template\n",
    "ORIGINAL_CHAT_TEMPLATE = model.tokenizer.chat_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batched inputs are padded, so the tokenizer needs a pad token. If it has none:\n",
    "# Qwen1 checkpoints use <|endoftext|>, other models reuse EOS, otherwise fail.\n",
    "if not model.tokenizer.pad_token:\n",
    "    if \"qwen1\" in MODEL_PATH.lower():\n",
    "        model.tokenizer.pad_token = \"<|endoftext|>\"\n",
    "    elif not model.tokenizer.eos_token:\n",
    "        raise ValueError(\"No pad token found in the tokenizer.\")\n",
    "    else:\n",
    "        model.tokenizer.pad_token = model.tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rF7e-u20EFTe"
   },
   "source": [
    "### Load harmful / harmless datasets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5i1XcVIgHEE1"
   },
   "outputs": [],
   "source": [
    "def get_harmful_instructions(timeout: float = 30.0):\n",
    "    \"\"\"Download the AdvBench harmful behaviors and split them 80/20 into train/test.\n",
    "\n",
    "    Args:\n",
    "        timeout: Seconds to wait for the HTTP request before giving up.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): lists of instruction strings from the CSV's \"goal\" column.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: if the download does not return a success status.\n",
    "    \"\"\"\n",
    "    url = \"https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv\"\n",
    "    # Without a timeout a stalled connection hangs the notebook; without\n",
    "    # raise_for_status an error page would be parsed as CSV.\n",
    "    response = requests.get(url, timeout=timeout)\n",
    "    response.raise_for_status()\n",
    "\n",
    "    dataset = pd.read_csv(io.StringIO(response.content.decode(\"utf-8\")))\n",
    "    instructions = dataset[\"goal\"].tolist()\n",
    "\n",
    "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
    "    return train, test\n",
    "\n",
    "\n",
    "def get_harmless_instructions():\n",
    "    \"\"\"Load Alpaca instructions that have no extra input and split them 80/20.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): lists of instruction strings.\n",
    "    \"\"\"\n",
    "    hf_path = \"tatsu-lab/alpaca\"\n",
    "    dataset = load_dataset(hf_path)\n",
    "\n",
    "    # Keep instructions whose \"input\" field is empty. Iterate the split directly\n",
    "    # instead of random access by index (which also looked each kept row up twice).\n",
    "    instructions = [\n",
    "        item[\"instruction\"]\n",
    "        for item in dataset[\"train\"]\n",
    "        if item[\"input\"].strip() == \"\"\n",
    "    ]\n",
    "\n",
    "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
    "    return train, test\n",
    "\n",
    "\n",
    "def get_harmful_instructions_jp():\n",
    "    \"\"\"Load Japanese do-not-answer questions of selected harm types; split 80/20.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): lists of question strings.\n",
    "    \"\"\"\n",
    "    dataset = load_dataset(\"kunishou/do-not-answer-ja\")\n",
    "\n",
    "    # Harm categories kept for the harmful set. Commented-out entries are\n",
    "    # categories that were deliberately excluded.\n",
    "    wanted_harm_types = (\n",
    "        # # Information Hazards\n",
    "        # \"Risks from leaking or inferring sensitive information (organization/gov)\",\n",
    "        # \"Compromise privacy by leaking or inferring private information (person/individual)\",\n",
    "        # Malicious Uses\n",
    "        \"Assisting illegal activities\",\n",
    "        \"Nudging or advising users to perform unethical or unsafe actions\",\n",
    "        \"Reducing the cost of disinformation campaigns\",\n",
    "        # Discrimination, Exclusion, Toxicity, Hateful, Offensive\n",
    "        \"Social stereotypes and unfair discrimination\",\n",
    "        \"Toxic language (hate speech)\",\n",
    "        \"Adult Content\",\n",
    "        # # Misinformation Harms\n",
    "        # \"Causing material harm by disseminating misinformation e.g. in medicine or law\",\n",
    "        # \"Disseminating false or misleading information\",\n",
    "        # # Human-Chatbot Interaction Harms\n",
    "        # \"Mental Health or Overreliance Crisis\",\n",
    "        # \"Treat Chatbot as a Human\",\n",
    "    )\n",
    "\n",
    "    instructions = [\n",
    "        row[\"question\"]\n",
    "        for row in dataset[\"train\"]\n",
    "        if row[\"types_of_harm\"] in wanted_harm_types\n",
    "    ]\n",
    "\n",
    "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
    "    return train, test\n",
    "\n",
    "\n",
    "def get_harmless_instructions_jp():\n",
    "    \"\"\"Load Japanese Alpaca instructions without extra input; split 80/20.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): lists of instruction strings.\n",
    "    \"\"\"\n",
    "    dataset = load_dataset(\"Lazycuber/alpaca-jp\")\n",
    "\n",
    "    instructions = []\n",
    "    for row in dataset[\"train\"]:\n",
    "        # Only self-contained instructions (empty \"input\" field) are kept.\n",
    "        if row[\"input\"].strip() == \"\":\n",
    "            # Drop surrounding Japanese corner brackets and single quotes.\n",
    "            instructions.append(row[\"instruction\"].strip(\"「」'\"))\n",
    "\n",
    "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
    "    return train, test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Rth8yvLZJsXs"
   },
   "outputs": [],
   "source": [
    "# \"en\": AdvBench (harmful) + Alpaca (harmless); \"jp\": do-not-answer-ja + alpaca-jp.\n",
    "LANGUAGE = \"en\"\n",
    "\n",
    "if LANGUAGE == \"en\":\n",
    "    harmful_inst_train, harmful_inst_test = get_harmful_instructions()\n",
    "    harmless_inst_train, harmless_inst_test = get_harmless_instructions()\n",
    "elif LANGUAGE == \"jp\":\n",
    "    harmful_inst_train, harmful_inst_test = get_harmful_instructions_jp()\n",
    "    harmless_inst_train, harmless_inst_test = get_harmless_instructions_jp()\n",
    "else:\n",
    "    # Fail here rather than with a NameError on the first use of the splits.\n",
    "    raise ValueError(f\"Unsupported LANGUAGE {LANGUAGE!r}; expected 'en' or 'jp'.\")\n",
    "\n",
    "print(f\"Train: {len(harmful_inst_train)} harmful, {len(harmless_inst_train)} harmless\")\n",
    "print(f\"Test: {len(harmful_inst_test)} harmful, {len(harmless_inst_test)} harmless\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Qv2ALDY_J44G",
    "outputId": "7cce7654-9592-4414-a322-27c6631f2e8e"
   },
   "outputs": [],
   "source": [
    "N_PREVIEW = 4  # example instructions shown per split\n",
    "\n",
    "for label, insts in [(\"Harmful\", harmful_inst_train), (\"Harmless\", harmless_inst_train)]:\n",
    "    print(f\"{label} instructions:\")\n",
    "    # Slicing (rather than indexing range(4)) tolerates splits shorter than N_PREVIEW.\n",
    "    for inst in insts[:N_PREVIEW]:\n",
    "        print(f\"\\t{inst}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KOKYA61k8LWt"
   },
   "source": [
    "### Tokenization utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P8UPQSfpWOSK"
   },
   "outputs": [],
   "source": [
    "def instructions_to_chat_tokens(\n",
    "    tokenizer: AutoTokenizer,\n",
    "    instructions: List[str],\n",
    ") -> Int[Tensor, \"batch_size seq_len\"]:\n",
    "    \"\"\"Tokenize a batch of instructions into a padded tensor of token ids.\n",
    "\n",
    "    With a chat template, each instruction becomes a single user turn and the\n",
    "    generation prompt is appended; without one, the raw strings are tokenized.\n",
    "    Sequences are padded (side given by the tokenizer) and never truncated.\n",
    "    \"\"\"\n",
    "    if not tokenizer.chat_template:\n",
    "        encoded = tokenizer(\n",
    "            instructions, padding=True, truncation=False, return_tensors=\"pt\"\n",
    "        )\n",
    "        return encoded.input_ids\n",
    "\n",
    "    conversations = [[{\"role\": \"user\", \"content\": text}] for text in instructions]\n",
    "    return tokenizer.apply_chat_template(\n",
    "        conversations,\n",
    "        add_generation_prompt=True,\n",
    "        padding=True,\n",
    "        truncation=False,\n",
    "        return_tensors=\"pt\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode a couple of templated prompts per class to sanity-check the chat\n",
    "# template and the left padding.\n",
    "harmful_sample_toks, harmless_sample_toks = (\n",
    "    instructions_to_chat_tokens(tokenizer=model.tokenizer, instructions=insts[:2])\n",
    "    for insts in (harmful_inst_train, harmless_inst_train)\n",
    ")\n",
    "\n",
    "for sample in (*harmful_sample_toks[:2], *harmless_sample_toks[:2]):\n",
    "    print(model.tokenizer.decode(sample))\n",
    "    print(\"-\" * 50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gtrIK8x78SZh"
   },
   "source": [
    "### Generation utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "94jRJDR0DRoY"
   },
   "outputs": [],
   "source": [
    "def _generate_with_hooks(\n",
    "    model: HookedTransformer,\n",
    "    toks: Int[Tensor, \"batch_size seq_len\"],\n",
    "    max_tokens_generated: int = 16,\n",
    "    fwd_hooks=None,\n",
    ") -> List[str]:\n",
    "    \"\"\"Greedily decode ``max_tokens_generated`` tokens with ``fwd_hooks`` attached.\n",
    "\n",
    "    Args:\n",
    "        model: TransformerLens model; its tokenizer decodes the result.\n",
    "        toks: Padded prompt token ids, shape (batch, seq_len).\n",
    "        max_tokens_generated: Number of new tokens per prompt. (Previously the\n",
    "            default was BATCH_SIZE, which silently changed when that global did.)\n",
    "        fwd_hooks: (hook_name, hook_fn) pairs active for every forward pass;\n",
    "            ``None`` means no hooks.\n",
    "\n",
    "    Returns:\n",
    "        One decoded string per prompt with only the generated tokens\n",
    "        (special tokens kept).\n",
    "    \"\"\"\n",
    "    hooks = [] if fwd_hooks is None else fwd_hooks\n",
    "    prompt_len = toks.shape[1]\n",
    "\n",
    "    # Prompt + generation buffer; generated slots are filled in one by one.\n",
    "    all_toks = torch.zeros(\n",
    "        (toks.shape[0], prompt_len + max_tokens_generated),\n",
    "        dtype=torch.long,\n",
    "        device=toks.device,\n",
    "    )\n",
    "    all_toks[:, :prompt_len] = toks\n",
    "\n",
    "    # Attach the hooks once for the whole decode loop instead of once per step.\n",
    "    with model.hooks(fwd_hooks=hooks):\n",
    "        for i in range(max_tokens_generated):\n",
    "            # No KV cache: the whole prefix generated so far is re-run each step.\n",
    "            logits = model(all_toks[:, : prompt_len + i])\n",
    "            # Greedy sampling (temperature=0).\n",
    "            all_toks[:, prompt_len + i] = logits[:, -1, :].argmax(dim=-1)\n",
    "\n",
    "    return model.tokenizer.batch_decode(\n",
    "        all_toks[:, prompt_len:], skip_special_tokens=False\n",
    "    )\n",
    "\n",
    "\n",
    "def get_generations(\n",
    "    model: HookedTransformer,\n",
    "    instructions: List[str],\n",
    "    tokenizer: AutoTokenizer,\n",
    "    fwd_hooks=None,\n",
    "    max_tokens_generated: int = 64,\n",
    "    batch_size: int = BATCH_SIZE,\n",
    ") -> List[str]:\n",
    "    \"\"\"Generate greedy completions for ``instructions``, ``batch_size`` at a time.\n",
    "\n",
    "    Args:\n",
    "        model: TransformerLens model used for generation.\n",
    "        instructions: User instructions, turned into prompts by\n",
    "            instructions_to_chat_tokens.\n",
    "        tokenizer: Tokenizer used to build the prompts.\n",
    "        fwd_hooks: (hook_name, hook_fn) pairs applied during generation;\n",
    "            ``None`` (the default) means no hooks.\n",
    "        max_tokens_generated: Number of new tokens per instruction.\n",
    "        batch_size: Number of instructions per batch.\n",
    "\n",
    "    Returns:\n",
    "        One decoded completion per instruction, in input order.\n",
    "    \"\"\"\n",
    "    # Imported locally so the progress bar keeps working even if the global\n",
    "    # name ``tqdm`` is rebound to the module (``import tqdm``) in another cell.\n",
    "    from tqdm import tqdm as progress_bar\n",
    "\n",
    "    hooks = [] if fwd_hooks is None else fwd_hooks\n",
    "    generations = []\n",
    "\n",
    "    for start in progress_bar(range(0, len(instructions), batch_size)):\n",
    "        toks = instructions_to_chat_tokens(\n",
    "            tokenizer=tokenizer, instructions=instructions[start : start + batch_size]\n",
    "        )\n",
    "\n",
    "        with torch.no_grad():\n",
    "            generation = _generate_with_hooks(\n",
    "                model,\n",
    "                toks,\n",
    "                max_tokens_generated=max_tokens_generated,\n",
    "                fwd_hooks=hooks,\n",
    "            )\n",
    "        generations.extend(generation)\n",
    "\n",
    "    return generations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_single_sample(model, input, tokenizer, fwd_hooks=None, max_tokens_generated=64):\n",
    "    \"\"\"Print the baseline and the hooked (intervention) completion for one prompt.\n",
    "\n",
    "    Args:\n",
    "        model: HookedTransformer used for generation.\n",
    "        input: A single instruction string.\n",
    "        tokenizer: Tokenizer used to build the prompt.\n",
    "        fwd_hooks: (hook_name, hook_fn) pairs applied only to the intervention run;\n",
    "            ``None`` means no hooks.\n",
    "        max_tokens_generated: Number of tokens generated for each completion.\n",
    "    \"\"\"\n",
    "    hooks = [] if fwd_hooks is None else fwd_hooks\n",
    "\n",
    "    # Generate one greedy completion for `input` with the given hooks.\n",
    "    def _complete(active_hooks):\n",
    "        return get_generations(\n",
    "            model,\n",
    "            [input],\n",
    "            tokenizer,\n",
    "            fwd_hooks=active_hooks,\n",
    "            max_tokens_generated=max_tokens_generated,\n",
    "        )[0]\n",
    "\n",
    "    # Print a colored header and wrapped completion, then reset the color so it\n",
    "    # does not bleed into later output.\n",
    "    def _show(color, header, text):\n",
    "        print(color + header)\n",
    "        print(\n",
    "            textwrap.fill(text, width=100, initial_indent=\"\\t\", subsequent_indent=\"\\t\")\n",
    "            + Fore.RESET\n",
    "        )\n",
    "\n",
    "    baseline = _complete([])\n",
    "    intervention = _complete(hooks)\n",
    "\n",
    "    print(f\"INSTRUCTION: {repr(input)}\")\n",
    "    _show(Fore.GREEN, \"BASELINE COMPLETION:\", baseline)\n",
    "    _show(Fore.RED, \"INTERVENTION COMPLETION:\", intervention)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W9O8dm0_EQRk"
   },
   "source": [
    "## Finding the \"refusal direction\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MbY79kSP8oOg"
   },
   "outputs": [],
   "source": [
    "def __run_with_cache(model, data, batch_size):\n",
    "    \"\"\"Run ``model`` over ``data`` in batches, caching every residual-stream hook.\n",
    "\n",
    "    Each batch's activations are moved to CPU and all batches are concatenated\n",
    "    along dim 0 at the end.\n",
    "\n",
    "    Args:\n",
    "        model: HookedTransformer to run.\n",
    "        data: Token ids, one example per row along dim 0.\n",
    "        batch_size: Number of examples per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        ActivationCache covering all of ``data`` for hooks whose name contains \"resid\".\n",
    "    \"\"\"\n",
    "    chunks = {}\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, len(data), batch_size):\n",
    "            _, batch_cache = model.run_with_cache(\n",
    "                data[start : start + batch_size],\n",
    "                names_filter=lambda hook_name: \"resid\" in hook_name,\n",
    "                return_cache_object=False,\n",
    "            )\n",
    "            for name, act in batch_cache.items():\n",
    "                chunks.setdefault(name, []).append(act.cpu())\n",
    "\n",
    "    # Concatenate once; growing each tensor with vstack per batch copies\n",
    "    # quadratically in the number of batches.\n",
    "    cache = {name: torch.cat(parts, dim=0) for name, parts in chunks.items()}\n",
    "    return ActivationCache(cache, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_template_suffix_toks(tokenizer):\n",
    "    \"\"\"Return the template tokens that follow the instruction in every prompt.\n",
    "\n",
    "    With left padding, all prompts built from the same template end with the same\n",
    "    token suffix. Activations on those tokens are computed after the whole prompt\n",
    "    has been processed, which makes them a natural place to compare contrastive\n",
    "    samples.\n",
    "\n",
    "    The suffix is everything after the last position where the tokenizations of\n",
    "    two one-character instructions differ.\n",
    "    \"\"\"\n",
    "    toks = instructions_to_chat_tokens(tokenizer=tokenizer, instructions=[\"a\", \"b\"])\n",
    "    first, second = toks[0], toks[1]\n",
    "    suffix = first\n",
    "    # Scan from the end: the first mismatch found marks where the shared suffix starts.\n",
    "    for i in range(len(first) - 1, -1, -1):\n",
    "        if first[i] != second[i]:\n",
    "            suffix = first[i + 1 :]\n",
    "            break  # without this, earlier mismatches would overwrite the suffix\n",
    "\n",
    "    return tokenizer.convert_ids_to_tokens(suffix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_activations(\n",
    "    model: HookedTransformer,\n",
    "    instructions: List[str],\n",
    "    batch_size: int = BATCH_SIZE,\n",
    "    act_names: List[str] = [\"resid_mid\", \"resid_post\"],\n",
    "    num_last_tokens: int = 1,\n",
    "):\n",
    "    \"\"\"Collect residual-stream activations at the last ``num_last_tokens`` positions.\n",
    "\n",
    "    Args:\n",
    "        model: HookedTransformer whose hooks are active while running.\n",
    "        instructions: Instructions to tokenize with the chat template.\n",
    "        batch_size: Examples per forward pass.\n",
    "        act_names: Hook names per block (e.g. \"resid_mid\"); not mutated.\n",
    "        num_last_tokens: Number of trailing token positions to keep.\n",
    "\n",
    "    Returns:\n",
    "        (acts, cache): ``acts`` is stacked as\n",
    "        layers x act_names x batch x tokens x dim; ``cache`` is the CPU\n",
    "        ActivationCache it was sliced from.\n",
    "    \"\"\"\n",
    "    toks = instructions_to_chat_tokens(\n",
    "        tokenizer=model.tokenizer, instructions=instructions\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "        # Use the caller's batch_size (previously the global BATCH_SIZE was used\n",
    "        # here, silently ignoring this argument).\n",
    "        cache = __run_with_cache(model, toks, batch_size=batch_size)\n",
    "\n",
    "    acts = torch.stack(\n",
    "        [\n",
    "            torch.stack(\n",
    "                [cache[act, layer][:, -num_last_tokens:, :] for act in act_names]\n",
    "            )\n",
    "            for layer in range(model.cfg.n_layers)\n",
    "        ]\n",
    "    )\n",
    "    return acts, cache"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract the activations\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_CAA_output_hook(layer, direction: Tensor):\n",
    "    def hook_fn(output, hook):\n",
    "        # RL: We want to ablate dest -> src  (Input should be negative of harmful - harmless)\n",
    "        nonlocal direction\n",
    "        # nonlocal direction\n",
    "\n",
    "        # RL: Obtain Activations (Might be a tuple so we obtain the activation component)\n",
    "        if isinstance(output, tuple):\n",
    "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output[0]\n",
    "        else:\n",
    "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output\n",
    "\n",
    "        # ActAdd alpha_param\n",
    "        alpha = 1.0\n",
    "        # RL: Normalize the direction (dir is 1D vector)\n",
    "        direction = direction / (direction.norm(p = 2) + 1e-8)  \n",
    "        direction = direction.to(activation)\n",
    "        activation += alpha * direction\n",
    "\n",
    "        if isinstance(output, tuple):\n",
    "            return (activation, *output[1:])\n",
    "        else:\n",
    "            return activation\n",
    "\n",
    "    return hook_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Steering coefficients for this run (they also name the per-run OUTPUT_DIR).\n",
    "{\"p_coe\": p_coe, \"i_coe\": i_coe, \"d_coe\": d_coe}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One sub-directory per (p, i, d) setting with \".\" spelled \"p\", e.g. pid_1p0_2p5_0p0.\n",
    "OUTPUT_PARENT_DIR = Path(\"output\") / MODEL_NAME / \"causal_noise\"\n",
    "\n",
    "OUTPUT_DIR = OUTPUT_PARENT_DIR / (\"pid_\" + \"_\".join(str(c).replace(\".\", \"p\") for c in (p_coe, i_coe, d_coe)))\n",
    "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the tqdm *function*: `import tqdm` would rebind the global name to the\n",
    "# module and break other cells that call tqdm(...) directly.\n",
    "from tqdm import tqdm\n",
    "\n",
    "N_INST_TRAIN = 512\n",
    "BATCH_SIZE = 512\n",
    "# Seed for the random noise mixed into every steering vector below, so a run\n",
    "# that misses the on-disk cache is reproducible.\n",
    "NOISE_SEED = 0\n",
    "\n",
    "# extraction points per decoder block\n",
    "act_names = [\"resid_mid\", \"resid_post\"]\n",
    "n_act = len(act_names)\n",
    "\n",
    "# get the template suffix tokens\n",
    "template_suffix_toks = get_template_suffix_toks(model.tokenizer)\n",
    "if not template_suffix_toks:\n",
    "    template_suffix_toks = [\"<last token>\"]\n",
    "\n",
    "# Only keep activations on the template suffix tokens: they are the same for every\n",
    "# sample and, because attention is causal, their outputs already depend on the prompt.\n",
    "num_last_tokens = len(template_suffix_toks)\n",
    "print(\"template_suffix_toks:\", template_suffix_toks)\n",
    "\n",
    "# Artifact paths (all inside OUTPUT_DIR).\n",
    "chosen_token = -1  # suffix position whose activations define the directions\n",
    "file_tag = f\"{LANGUAGE}_{MODEL_PATH.split('/')[-1]}\"\n",
    "refusal_dirs_path = OUTPUT_DIR / f\"refusal_dirs_{chosen_token}_{file_tag}.npy\"\n",
    "unnormed_refusal_dirs_path = OUTPUT_DIR / f\"refusal_dirs_unnormed_{chosen_token}_{file_tag}.npy\"\n",
    "p_dirs_path = OUTPUT_DIR / f\"p_dirs_{chosen_token}_{file_tag}.npy\"\n",
    "i_dirs_path = OUTPUT_DIR / f\"i_dirs_{chosen_token}_{file_tag}.npy\"\n",
    "d_dirs_path = OUTPUT_DIR / f\"d_dirs_{chosen_token}_{file_tag}.npy\"\n",
    "# Harmful activations are never steered, so they are computed once and reused.\n",
    "output_harmful_file = OUTPUT_DIR / f\"acts_harmful_{file_tag}.npy\"\n",
    "# Harmless activations are re-collected under steering at every step; the saved\n",
    "# file holds those from the last step of the loop below.\n",
    "output_harmless_file = OUTPUT_DIR / f\"acts_harmless_{file_tag}.npy\"\n",
    "\n",
    "cached_artifacts = [\n",
    "    refusal_dirs_path,\n",
    "    unnormed_refusal_dirs_path,\n",
    "    output_harmful_file,\n",
    "    output_harmless_file,\n",
    "    p_dirs_path,\n",
    "    i_dirs_path,\n",
    "    d_dirs_path,\n",
    "]\n",
    "\n",
    "if all(path.exists() for path in cached_artifacts):\n",
    "    print(\"loading refusal_dirs and unnormed refusal_dirs from file\")\n",
    "    unnormed_refusal_dirs = np.load(unnormed_refusal_dirs_path)\n",
    "    refusal_dirs = np.load(refusal_dirs_path)\n",
    "    print(\"loading harmful and sequentially steered harmless files\")\n",
    "    harmful_acts = torch.from_numpy(np.load(output_harmful_file))\n",
    "    harmless_acts = torch.from_numpy(np.load(output_harmless_file))\n",
    "    # Restore the P/I/D components too, so the analysis cells below also work\n",
    "    # on a cache hit (they read ref_dir_set and int_comp).\n",
    "    ref_dir_set = torch.from_numpy(np.load(p_dirs_path))\n",
    "    int_comp = torch.from_numpy(np.load(i_dirs_path))\n",
    "    der_comp = torch.from_numpy(np.load(d_dirs_path))\n",
    "else:\n",
    "    # Momentum mode: each steering vector combines P/I/D terms plus noise.\n",
    "    momentum_mode = True\n",
    "    v = None\n",
    "\n",
    "    # Harmful prompts are never steered, so their activations are computed once.\n",
    "    if output_harmful_file.exists():\n",
    "        harmful_acts = torch.from_numpy(np.load(output_harmful_file))\n",
    "    else:\n",
    "        harmful_acts, cache = get_activations(\n",
    "            model,\n",
    "            harmful_inst_train[:N_INST_TRAIN],\n",
    "            batch_size=BATCH_SIZE,\n",
    "            act_names=act_names,\n",
    "            num_last_tokens=num_last_tokens,\n",
    "        )\n",
    "        np.save(output_harmful_file, harmful_acts.cpu().float().numpy())\n",
    "\n",
    "    # Normalize each activation vector, then average over the batch dimension (2).\n",
    "    harmful_acts_norm = harmful_acts / harmful_acts.norm(dim=-1, keepdim=True)\n",
    "    harmful_acts_norm_mean = harmful_acts_norm.mean(dim=2)\n",
    "\n",
    "    # directions[layer][pos]: steering vector for act_names[pos] of that layer.\n",
    "    directions = {}\n",
    "    d_model = harmful_acts.shape[-1]\n",
    "    unnormed_refusal_dirs = torch.zeros(model.cfg.n_layers * n_act, d_model)\n",
    "\n",
    "    # (layer, pos) pairs whose steering vector is already fixed, in extraction order.\n",
    "    used_module_names = []\n",
    "    torch.manual_seed(NOISE_SEED)\n",
    "    for l in tqdm(range(model.cfg.n_layers * n_act)):\n",
    "        layer, pos = divmod(l, n_act)\n",
    "\n",
    "        # Steer every module fixed so far. Directions point from harmless towards\n",
    "        # harmful, so they are added as-is.\n",
    "        fwd_hooks = [\n",
    "            (f\"blocks.{ly}.hook_{act_names[mp]}\", get_CAA_output_hook(ly, directions[ly][mp]))\n",
    "            for ly, mp in used_module_names\n",
    "        ]\n",
    "\n",
    "        # Re-collect harmless activations with the current steering applied.\n",
    "        with model.hooks(fwd_hooks=fwd_hooks):\n",
    "            harmless_acts, cache = get_activations(\n",
    "                model,\n",
    "                harmless_inst_train[:N_INST_TRAIN],\n",
    "                batch_size=BATCH_SIZE,\n",
    "                act_names=act_names,\n",
    "                num_last_tokens=num_last_tokens,\n",
    "            )\n",
    "\n",
    "        # Difference of normalized means (harmful - harmless) for every module at the\n",
    "        # chosen token, flattened to (n_layers * n_act, d_model).\n",
    "        harmless_acts_norm = harmless_acts / harmless_acts.norm(dim=-1, keepdim=True)\n",
    "        harmless_acts_norm_mean = harmless_acts_norm.mean(dim=2)\n",
    "        ref_dir_set = (harmful_acts_norm_mean - harmless_acts_norm_mean)[:, :, chosen_token]\n",
    "        ref_dir_set = ref_dir_set.reshape(-1, d_model)\n",
    "        ref_dir = ref_dir_set[l]  # row l == layer * n_act + pos\n",
    "\n",
    "        # D term: first difference across modules (zero for module 0).\n",
    "        shifted_ref_dir = ref_dir_set.roll(1, dims=0)\n",
    "        shifted_ref_dir[0] = ref_dir_set[0]\n",
    "        der_comp = ref_dir_set - shifted_ref_dir\n",
    "        # I term: running sum across modules.\n",
    "        int_comp = torch.cumsum(ref_dir_set, dim=0)\n",
    "\n",
    "        if pos == 0:\n",
    "            directions[layer] = {}\n",
    "        if momentum_mode:\n",
    "            # NOTE(review): Gaussian noise shifted by -0.3 is added on purpose\n",
    "            # (results go under \"causal_noise\"); confirm the offset is intended.\n",
    "            noise = torch.randn_like(ref_dir) - 0.3\n",
    "            if l == 0:\n",
    "                v = ref_dir + noise\n",
    "            else:\n",
    "                v = p_coe * ref_dir + i_coe * int_comp[l] + d_coe * der_comp[l] + noise\n",
    "            directions[layer][pos] = v\n",
    "        else:\n",
    "            directions[layer][pos] = ref_dir\n",
    "        unnormed_refusal_dirs[l, :] = directions[layer][pos].detach()\n",
    "        used_module_names.append((layer, pos))\n",
    "\n",
    "    # Unit-normalize each steering vector; reshape to (n_layers, n_act, d_model).\n",
    "    refusal_dirs = unnormed_refusal_dirs / unnormed_refusal_dirs.norm(dim=-1, keepdim=True)\n",
    "    refusal_dirs = refusal_dirs.reshape(model.cfg.n_layers, n_act, d_model).cpu().float().numpy()\n",
    "    unnormed_refusal_dirs = unnormed_refusal_dirs.reshape(model.cfg.n_layers, n_act, d_model).cpu().float().numpy()\n",
    "\n",
    "    # Same dtype/device as the cache-hit branch.\n",
    "    harmful_acts = harmful_acts.cpu().float()\n",
    "    harmless_acts = harmless_acts.cpu().float()\n",
    "    ref_dir_set, int_comp, der_comp = (t.cpu().float() for t in (ref_dir_set, int_comp, der_comp))\n",
    "\n",
    "    np.save(output_harmless_file, harmless_acts.numpy())\n",
    "    np.save(unnormed_refusal_dirs_path, unnormed_refusal_dirs)\n",
    "    np.save(refusal_dirs_path, refusal_dirs)\n",
    "    np.save(p_dirs_path, ref_dir_set.numpy())\n",
    "    np.save(i_dirs_path, int_comp.numpy())\n",
    "    np.save(d_dirs_path, der_comp.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the per-module norms of the difference-of-means directions.\n",
    "ref_dir_set.norm(dim=-1, keepdim=True).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project each module's difference-of-means direction e(t) (rows of ref_dir_set)\n",
    "# and its running sum s(t) (rows of int_comp) onto the first module's e(0).\n",
    "e0 = ref_dir_set[0, :]\n",
    "y = (e0 @ ref_dir_set.T).cpu().float().numpy()  # <e(0), e(t)>\n",
    "x = (e0 @ int_comp.T).cpu().float().numpy()  # <e(0), s(t)>\n",
    "indices = list(range(x.shape[0]))\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "# Left panel: <e(0), s(t)> vs <e(0), e(t)>\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=x,\n",
    "    y=y,\n",
    "    mode='lines+markers',\n",
    "    name=\"<e(0), s(t)> vs <e(0), e(t)>\",\n",
    "    text=indices,\n",
    "    hovertemplate=\"Index: %{text}<br>x: %{x}<br>y: %{y}<extra></extra>\",\n",
    "    marker=dict(size=8, color='blue', opacity=0.7),\n",
    "    line=dict(color='lightblue', width=2),\n",
    "    xaxis=\"x\",\n",
    "    yaxis=\"y\"\n",
    "))\n",
    "\n",
    "# Right panel: module index vs <e(0), e(t)>\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=indices,\n",
    "    y=y,\n",
    "    mode='lines+markers',\n",
    "    name=\"Index vs <e(0), e(t)>\",\n",
    "    hovertemplate=\"Index: %{x}<br>y: %{y}<extra></extra>\",\n",
    "    marker=dict(size=8, color='red', opacity=0.7),\n",
    "    line=dict(color='pink', width=2),\n",
    "    xaxis=\"x2\",\n",
    "    yaxis=\"y2\"\n",
    "))\n",
    "\n",
    "# Two side-by-side panels; each x-axis is anchored to its own y-axis.\n",
    "fig.update_layout(\n",
    "    width=1200,\n",
    "    height=600,\n",
    "    showlegend=True,\n",
    "    xaxis=dict(domain=[0, 0.45], title=\"<e(0), s(t)>\", anchor=\"y\"),\n",
    "    yaxis=dict(title=\"<e(0), e(t)>\", anchor=\"x\"),\n",
    "    xaxis2=dict(domain=[0.55, 1.0], title=\"Index\", anchor=\"y2\"),\n",
    "    yaxis2=dict(title=\"<e(0), e(t)>\", anchor=\"x2\")\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def median_of_means(\n",
    "    tensor: torch.Tensor,\n",
    "    num_blocks: int,\n",
    "    dim: int,\n",
    "    generator: \"torch.Generator | None\" = None,\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Compute the Median of Means (MoM) of `tensor` along `dim`.\n",
    "\n",
    "    The N samples along `dim` are randomly shuffled and split into `num_blocks`\n",
    "    contiguous blocks (the first N % num_blocks blocks get one extra sample).\n",
    "    Each block is averaged and the element-wise median of the block means is\n",
    "    returned. Note that `torch.median` returns the lower of the two middle\n",
    "    values when `num_blocks` is even; it does not average them.\n",
    "\n",
    "    Args:\n",
    "        tensor (torch.Tensor): Input tensor of rank >= 1.\n",
    "        num_blocks (int): Number of blocks to divide the data along `dim`.\n",
    "        dim (int): The dimension along which to compute the MoM. Negative\n",
    "            values count from the end, as in other torch reductions.\n",
    "        generator (torch.Generator, optional): RNG for the shuffle (must live on\n",
    "            `tensor.device`). Pass a seeded generator for reproducible results;\n",
    "            defaults to the global torch RNG.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Tensor with the input's shape minus `dim` (a 0-d tensor\n",
    "        for 1-D input).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `num_blocks` is not in [1, tensor.size(dim)].\n",
    "    \"\"\"\n",
    "    N = tensor.size(dim)\n",
    "    if num_blocks <= 0 or num_blocks > N:\n",
    "        raise ValueError(f\"num_blocks must be between 1 and {N}\")\n",
    "    # Normalize negative dims: the shape bookkeeping below compares axis indices.\n",
    "    dim = dim % tensor.ndim\n",
    "\n",
    "    # Move `dim` to the front and flatten the remaining axes: (N, M)\n",
    "    tensor_flat = tensor.movedim(dim, 0).reshape(N, -1)\n",
    "\n",
    "    # Shuffle along N so block membership is random\n",
    "    idx = torch.randperm(N, generator=generator, device=tensor.device)\n",
    "    shuffled = tensor_flat[idx]\n",
    "\n",
    "    # tensor_split with an integer section count yields exactly the block sizes\n",
    "    # described above (N // num_blocks, plus one for the first N % num_blocks).\n",
    "    blocks = torch.tensor_split(shuffled, num_blocks, dim=0)\n",
    "    block_means = torch.stack([block.mean(dim=0) for block in blocks], dim=0)  # (num_blocks, M)\n",
    "    mom_flat = block_means.median(dim=0).values  # (M,)\n",
    "\n",
    "    # reshape (not view) so an empty target shape gives a 0-d tensor for 1-D input\n",
    "    reduced_shape = [size for i, size in enumerate(tensor.shape) if i != dim]\n",
    "    return mom_flat.reshape(reduced_shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analyze the activations\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tqD5E8Vc_w5d"
   },
   "outputs": [],
   "source": [
    "from torch.nn.functional import cosine_similarity, normalize\n",
    "\n",
    "\n",
    "def get_pairwise_cosine_similarity(acts_normed, device=None):\n",
    "    \"\"\"\n",
    "    Cosine similarity of every unordered pair of samples, per token position.\n",
    "\n",
    "    Args:\n",
    "        acts_normed: Array-like (numpy array or tensor) of unit-norm vectors\n",
    "            shaped (..., batch, toks, dim). Because the vectors are already\n",
    "            normalized, a plain dot product equals the cosine similarity.\n",
    "        device: Device for the computation. Defaults to \"cuda:0\" when CUDA is\n",
    "            available and \"cpu\" otherwise.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor of shape (..., batch * (batch - 1) // 2, toks) holding the\n",
    "        strict upper triangle (i < j) of the batch x batch similarity matrix,\n",
    "        in row-major order.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "    acts_normed = torch.as_tensor(acts_normed, device=device)\n",
    "\n",
    "    # ... batch1 toks dim, ... batch2 toks dim -> ... toks batch1 batch2\n",
    "    acts_pairwise_sim = torch.einsum(\"...ikl,...jkl->...kij\", acts_normed, acts_normed)\n",
    "\n",
    "    # Keep only the strict upper triangle of the last two axes: ... x toks x n_pairs\n",
    "    batch_size = acts_pairwise_sim.shape[-1]\n",
    "    rows, cols = torch.triu_indices(\n",
    "        batch_size, batch_size, offset=1, device=acts_pairwise_sim.device\n",
    "    )\n",
    "    acts_pairwise_sim = acts_pairwise_sim[..., rows, cols]\n",
    "\n",
    "    # ... x n_pairs x toks\n",
    "    return acts_pairwise_sim.swapaxes(-1, -2)\n",
    "\n",
    "\n",
    "def get_cosine_with_mean(acts_normed, device=None):\n",
    "    \"\"\"\n",
    "    Cosine similarity of each sample vector with the re-normalized mean vector.\n",
    "\n",
    "    Args:\n",
    "        acts_normed: Array-like (numpy array or tensor) of unit-norm vectors\n",
    "            shaped (..., batch, toks, dim).\n",
    "        device: Device for the computation. Defaults to \"cuda:0\" when CUDA is\n",
    "            available and \"cpu\" otherwise.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor of shape (..., batch, toks).\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "    acts_normed = torch.as_tensor(acts_normed, device=device)\n",
    "\n",
    "    # Average over the batch axis (third from the end; axis 2 for a\n",
    "    # layers x modules x batch x toks x dim input), then re-normalize.\n",
    "    mean_act = acts_normed.mean(dim=-3)\n",
    "    mean_act = mean_act / mean_act.norm(dim=-1, keepdim=True)\n",
    "\n",
    "    # ... batch toks dim, ... toks dim -> ... batch toks\n",
    "    return torch.einsum(\"...ijk,...jk->...ij\", acts_normed, mean_act)\n",
    "\n",
    "\n",
    "# layers x resid_modules x batch x tokens x dim\n",
    "harmful_acts_normed = harmful_acts / harmful_acts.norm(dim=-1, keepdim=True)\n",
    "harmless_acts_normed = harmless_acts / harmless_acts.norm(dim=-1, keepdim=True)\n",
    "\n",
    "# shape: layers x resid_modules x tokens x dim\n",
    "# Normalize first, then average over the batch: magnitudes are irrelevant after\n",
    "# the RMSNorm layer, so averaging unit vectors preserves the directions that\n",
    "# the next layer actually sees.\n",
    "\n",
    "print(harmful_acts_normed.shape)\n",
    "harmful_acts_normed_mean = harmful_acts_normed.mean(dim=2)\n",
    "harmless_acts_normed_mean = harmless_acts_normed.mean(dim=2)\n",
    "\n",
    "# Robust counterpart of the means above (median of 8 random block means).\n",
    "# The block shuffle uses the global torch RNG, so results vary between runs\n",
    "# unless torch is seeded beforehand.\n",
    "harmful_acts_normed_mom = median_of_means(harmful_acts_normed, num_blocks=8, dim=2)\n",
    "harmless_acts_normed_mom = median_of_means(harmless_acts_normed, num_blocks=8, dim=2)\n",
    "# layers x resid_modules x tokens\n",
    "similarity_scores = (\n",
    "    cosine_similarity(harmful_acts_normed_mean, harmless_acts_normed_mean, dim=-1)\n",
    "    .cpu()\n",
    "    .float()\n",
    "    .numpy()\n",
    ")\n",
    "\n",
    "\n",
    "hidden_dim = harmful_acts.shape[-1]\n",
    "# shape: layers x resid_modules x tokens x dim\n",
    "# Rescale the unit vectors by sqrt(hidden_dim), as RMSNorm does, so that every\n",
    "# vector has a root-mean-square entry of 1 regardless of the hidden size.\n",
    "# (This fixes the scale only; it does not make the entries standard normal.)\n",
    "# The variance is then taken across the batch for each coordinate, which keeps\n",
    "# the values comparable between models with different hidden dimensions.\n",
    "harmful_acts_normed_var = (\n",
    "    (harmful_acts_normed * np.sqrt(hidden_dim)).var(dim=2).cpu().float().numpy()\n",
    ")\n",
    "harmless_acts_normed_var = (\n",
    "    (harmless_acts_normed * np.sqrt(hidden_dim)).var(dim=2).cpu().float().numpy()\n",
    ")\n",
    "\n",
    "harmful_acts_normed = harmful_acts_normed.cpu().float().numpy()\n",
    "harmless_acts_normed = harmless_acts_normed.cpu().float().numpy()\n",
    "\n",
    "\n",
    "# layers x resid_modules x batch x tokens\n",
    "# cosine of each vector with the mean vector\n",
    "harmful_acts_cosine_with_mean = get_cosine_with_mean(harmful_acts_normed).cpu().numpy()\n",
    "harmless_acts_cosine_with_mean = (\n",
    "    get_cosine_with_mean(harmless_acts_normed).cpu().numpy()\n",
    ")\n",
    "\n",
    "# layers x resid_modules x (batch * (batch - 1) // 2) x tokens\n",
    "# cosine similarity of each pair of vectors\n",
    "harmful_acts_pairwise_sim = (\n",
    "    get_pairwise_cosine_similarity(harmful_acts_normed).cpu().numpy()\n",
    ")\n",
    "harmless_acts_pairwise_sim = (\n",
    "    get_pairwise_cosine_similarity(harmless_acts_normed).cpu().numpy()\n",
    ")\n",
    "\n",
    "# layers x resid_modules x tokens\n",
    "# variance (over all sample pairs) of the pairwise cosine similarity\n",
    "harmful_acts_pairwise_sim_var = np.var(harmful_acts_pairwise_sim, axis=-2)\n",
    "harmless_acts_pairwise_sim_var = np.var(harmless_acts_pairwise_sim, axis=-2)\n",
    "\n",
    "acts_normed_var = dict()\n",
    "\n",
    "# layers x resid_modules x tokens (mean / max over the hidden dimension)\n",
    "acts_normed_var[\"harmful\"] = dict(\n",
    "    mean=harmful_acts_normed_var.mean(axis=-1),\n",
    "    max=harmful_acts_normed_var.max(axis=-1),\n",
    ")\n",
    "\n",
    "# layers x resid_modules x tokens (mean / max over the hidden dimension)\n",
    "acts_normed_var[\"harmless\"] = dict(\n",
    "    mean=harmless_acts_normed_var.mean(axis=-1),\n",
    "    max=harmless_acts_normed_var.max(axis=-1),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop unreachable Python objects first so their tensors are freed, then ask\n",
    "# PyTorch's CUDA caching allocator to release the cached blocks.\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the activations\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Cosine Similarity between harmful and harmless activations at each layer and token position\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers, num_act_modules, num_tokens = similarity_scores.shape\n",
    "# Collapse (layer, module) into one row axis: (layers * resid_modules) x tokens\n",
    "data = similarity_scores.reshape(-1, similarity_scores.shape[-1])\n",
    "y_labels = [f\"{layer}-{module}\" for layer in range(num_layers) for module in (\"mid\", \"post\")]\n",
    "x_labels = list(map(repr, template_suffix_toks))\n",
    "\n",
    "fig = px.imshow(\n",
    "    data,\n",
    "    y=y_labels,\n",
    "    aspect=\"auto\",\n",
    "    labels=dict(x=\"token position\", y=\"layer\", color=\"cosine similarity\"),\n",
    ")\n",
    "fig.update_layout(\n",
    "    title=\"Cosine Similarity between harmful and harmless activations at each layer and token position\",\n",
    "    xaxis=dict(\n",
    "        tickmode=\"array\",\n",
    "        tickvals=list(range(len(x_labels))),\n",
    "        ticktext=x_labels,\n",
    "    ),\n",
    "    # One tick per layer (every len(act_names) rows); plotly pairs ticktext\n",
    "    # with tickvals positionally, so the n-th tick is labelled n.\n",
    "    yaxis=dict(\n",
    "        tickmode=\"array\",\n",
    "        tickvals=list(range(0, len(y_labels), len(act_names))),\n",
    "        ticktext=list(range(len(y_labels))),\n",
    "    ),\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def variance_plot(**kwargs):\n",
    "    \"\"\"\n",
    "    Build a filled band trace covering mean +/- one standard deviation.\n",
    "\n",
    "    Keyword Args:\n",
    "        x: Sequence of x positions. It is converted to a list, so numpy arrays\n",
    "            and tuples are concatenated rather than added element-wise.\n",
    "        y: 2-D torch tensor or numpy array shaped (len(x), n_samples); the mean\n",
    "            and the std (torch default, unbiased) are taken over the last axis.\n",
    "        fillcolor: Band colour, also used for the zero-width outline. Optional;\n",
    "            when omitted, plotly picks the colour.\n",
    "        **kwargs: Remaining keyword arguments are forwarded to go.Scatter.\n",
    "\n",
    "    Returns:\n",
    "        go.Scatter: A closed polygon (upper edge left-to-right, lower edge\n",
    "        right-to-left) filled with fill=\"toself\".\n",
    "    \"\"\"\n",
    "    x = list(kwargs.pop(\"x\"))\n",
    "    y = torch.as_tensor(kwargs.pop(\"y\"))\n",
    "    y_mean = y.mean(dim=-1)\n",
    "    y_std = y.std(dim=-1)\n",
    "    y_upper = (y_mean + y_std).tolist()\n",
    "    y_lower = (y_mean - y_std).tolist()\n",
    "\n",
    "    return go.Scatter(\n",
    "        x=x + x[::-1],\n",
    "        y=y_upper + y_lower[::-1],\n",
    "        mode=\"lines\",\n",
    "        fill=\"toself\",\n",
    "        line=dict(color=kwargs.get(\"fillcolor\"), width=0),\n",
    "        **kwargs,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Activation norms at each extraction point\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers, num_act_modules, num_tokens = similarity_scores.shape\n",
    "\n",
    "chosen_token = -1\n",
    "colour_map = {\n",
    "    \"harmless\": plotly.colors.qualitative.Plotly[0],\n",
    "    \"harmful\": plotly.colors.qualitative.Plotly[1],\n",
    "    \"neutral\": plotly.colors.qualitative.Plotly[3],\n",
    "}\n",
    "colour_map_light = {\n",
    "    \"harmless\": plotly.colors.qualitative.Pastel1[1],\n",
    "    \"harmful\": plotly.colors.qualitative.Pastel1[0],\n",
    "    \"neutral\": plotly.colors.qualitative.Pastel1[3],\n",
    "}\n",
    "# Semi-transparent fills for the +/- 1 std bands\n",
    "colour_map_opaque = {\n",
    "    \"harmless\": \"rgba(179, 205, 227, 0.3)\",\n",
    "    \"harmful\": \"rgba(251, 180, 174, 0.3)\",\n",
    "}\n",
    "\n",
    "# layers x resid_modules x batch x tokens x dim -- the layout that\n",
    "# harmful_acts_normed.mean(dim=2) and get_cosine_with_mean rely on\n",
    "acts = {\"harmful\": harmful_acts, \"harmless\": harmless_acts}\n",
    "\n",
    "categories = [\"harmless\", \"harmful\"]\n",
    "resid_modules = [\"mid\", \"post\"]\n",
    "\n",
    "# One categorical x position per (layer, module) extraction point, layer-major\n",
    "x_values = [str(i) for i in range(num_layers * num_act_modules)]\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "for category in categories:\n",
    "    # layers x resid_modules x batch x tokens\n",
    "    normed_acts = acts[category].norm(dim=-1)\n",
    "    # Average over the batch (axis -2), not over tokens, so indexing the last\n",
    "    # axis with chosen_token selects a token position: layers x modules x tokens\n",
    "    mean_normed_acts = normed_acts.mean(dim=-2)\n",
    "\n",
    "    # one mean norm per extraction point\n",
    "    y_values = mean_normed_acts[..., chosen_token].flatten()\n",
    "\n",
    "    # mean\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=y_values,\n",
    "            name=category,\n",
    "            mode=\"lines+markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map[category], size=3),\n",
    "            showlegend=True,\n",
    "        )\n",
    "    )\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=y_values,\n",
    "            mode=\"lines+markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map_light[category], size=3),\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # +/- 1 std band across the batch at the chosen token\n",
    "    fig.add_trace(\n",
    "        variance_plot(\n",
    "            x=x_values,\n",
    "            y=normed_acts[..., chosen_token].reshape(-1, normed_acts.shape[-2]),\n",
    "            yaxis=\"y\",\n",
    "            fillcolor=colour_map_opaque[category],\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # dot markers\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values[1::],\n",
    "            y=y_values[1::],\n",
    "            name=f\"{category}\",\n",
    "            mode=\"markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map[category], size=3),\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "fig.update_layout(\n",
    "    plot_bgcolor=\"white\",\n",
    "    grid=dict(rows=1, columns=1),\n",
    "    xaxis=dict(\n",
    "        type=\"category\",\n",
    "        dtick=4,\n",
    "        title=dict(text=\"Extraction Point\", font=dict(size=20)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        tickfont=dict(size=18),\n",
    "    ),\n",
    "    yaxis=dict(\n",
    "        title=dict(text=\"Activation Norm\", font=dict(size=20)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        zeroline=False,\n",
    "        tickfont=dict(size=18),\n",
    "    ),\n",
    "    hovermode=\"x unified\",\n",
    "    height=250,\n",
    "    width=600,\n",
    "    margin=dict(l=0, r=0, t=0, b=0),\n",
    "    legend=dict(x=0.05, y=0.95, font=dict(size=18)),\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "fig.write_image(VISUALIZATION_DIR / \"acts_norm.pdf\", scale=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Variance of normed activations at each extraction point\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "num_layers, num_act_modules, num_tokens = similarity_scores.shape\n",
    "\n",
    "chosen_token = -1\n",
    "colour_map = {\n",
    "    \"harmless\": plotly.colors.qualitative.Plotly[0],\n",
    "    \"harmful\": plotly.colors.qualitative.Plotly[1],\n",
    "    \"neutral\": plotly.colors.qualitative.Plotly[3],\n",
    "}\n",
    "colour_map_light = {\n",
    "    \"harmless\": plotly.colors.qualitative.Pastel1[1],\n",
    "    \"harmful\": plotly.colors.qualitative.Pastel1[0],\n",
    "    \"neutral\": plotly.colors.qualitative.Pastel1[3],\n",
    "}\n",
    "\n",
    "categories = [\"harmless\", \"harmful\"]\n",
    "resid_modules = [\"mid\", \"post\"]\n",
    "metrics = [\"mean\", \"max\"]\n",
    "\n",
    "x_values = sum([[f\"{l}\", f\"{l}-post\"] for l in range(num_layers)], [])\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "# Rows 1-2: mean / max (over hidden dims) of the per-coordinate batch variance\n",
    "for m, metric in enumerate(metrics):\n",
    "    for category in categories:\n",
    "        # layers x resid_modules at the chosen token\n",
    "        y_values = acts_normed_var[category][metric][..., chosen_token]\n",
    "        fig.add_trace(\n",
    "            go.Scatter(\n",
    "                x=x_values,\n",
    "                y=y_values.flatten(),\n",
    "                mode=\"lines\",\n",
    "                yaxis=f\"y{m + 1}\",\n",
    "                marker=dict(color=colour_map_light[category], size=5),\n",
    "                showlegend=False,\n",
    "            )\n",
    "        )\n",
    "        for module_idx, module_name in enumerate(resid_modules):\n",
    "            if module_name == \"mid\":\n",
    "                colour = colour_map[category]\n",
    "            else:\n",
    "                colour = colour_map_light[category]\n",
    "            fig.add_trace(\n",
    "                go.Scatter(\n",
    "                    x=x_values[module_idx::2],\n",
    "                    y=y_values.flatten()[module_idx::2],\n",
    "                    name=f\"{category}-{module_name}\",\n",
    "                    mode=\"markers\",\n",
    "                    yaxis=f\"y{m + 1}\",\n",
    "                    marker=dict(color=colour, size=5),\n",
    "                    showlegend=m == 0,\n",
    "                )\n",
    "            )\n",
    "\n",
    "# Row 3: harmless - harmful difference of the mean variance\n",
    "diff_mean_var = (\n",
    "    acts_normed_var[\"harmless\"][\"mean\"][..., chosen_token]\n",
    "    - acts_normed_var[\"harmful\"][\"mean\"][..., chosen_token]\n",
    ")\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=x_values,\n",
    "        y=diff_mean_var.flatten(),\n",
    "        mode=\"lines\",\n",
    "        yaxis=\"y3\",\n",
    "        marker=dict(color=colour_map_light[\"neutral\"], size=5),\n",
    "        showlegend=False,\n",
    "    )\n",
    ")\n",
    "for module_idx, module_name in enumerate(resid_modules):\n",
    "    if module_name == \"mid\":\n",
    "        colour = colour_map[\"neutral\"]\n",
    "    else:\n",
    "        colour = colour_map_light[\"neutral\"]\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values[module_idx::2],\n",
    "            y=diff_mean_var.flatten()[module_idx::2],\n",
    "            name=f\"(harmless - harmful)-{module_name}\",\n",
    "            mode=\"markers\",\n",
    "            yaxis=\"y3\",\n",
    "            marker=dict(color=colour, size=5),\n",
    "            showlegend=True,\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "fig.update_layout(\n",
    "    title=f\"Variance of normed activations at each layer for {MODEL_PATH}\",\n",
    "    plot_bgcolor=\"white\",\n",
    "    grid=dict(rows=3, columns=1),\n",
    "    xaxis=dict(\n",
    "        type=\"category\", dtick=2, title=\"Transformers Block\", gridcolor=\"lightgrey\"\n",
    "    ),\n",
    "    yaxis=dict(title=f\"{metrics[0]} variance\", gridcolor=\"lightgrey\", zeroline=False),\n",
    "    yaxis2=dict(title=f\"{metrics[1]} variance\", gridcolor=\"lightgrey\", zeroline=False),\n",
    "    yaxis3=dict(title=\"harmless - harmful\", gridcolor=\"lightgrey\", zeroline=False),\n",
    "    hovermode=\"x unified\",\n",
    "    height=1200,\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Statistics comparing harmful and harmless activations at each extraction point\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "x_values = [label for l in range(num_layers) for label in (f\"{l}\", f\"{l}-post\")]\n",
    "\n",
    "# (1) Angle between the harmful and harmless mean activations at every\n",
    "# extraction point (layers x resid_modules). Clipping guards against cosines\n",
    "# that round to just outside [-1, 1], where arccos would return NaN.\n",
    "sparsity_scores = np.arccos(similarity_scores[..., chosen_token].clip(-1, 1))\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=x_values,\n",
    "        y=sparsity_scores.flatten(),\n",
    "        name=\"harmful-harmless angle\",\n",
    "        mode=\"lines+markers\",\n",
    "        marker=dict(color=colour_map_light[\"neutral\"], size=5),\n",
    "        yaxis=\"y\",\n",
    "    )\n",
    ")\n",
    "# Highlight the post-module points (module index 1) in the darker colour\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=x_values[1::2],\n",
    "        y=sparsity_scores[:, 1].flatten(),\n",
    "        showlegend=False,\n",
    "        mode=\"markers\",\n",
    "        marker=dict(color=colour_map[\"neutral\"], size=5),\n",
    "        yaxis=\"y\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# (2) Spread of each class: angle of every sample to its class mean\n",
    "for category in categories:\n",
    "    source_cosines = (\n",
    "        harmful_acts_cosine_with_mean\n",
    "        if category == \"harmful\"\n",
    "        else harmless_acts_cosine_with_mean\n",
    "    )\n",
    "    # layers x resid_modules x batch\n",
    "    acts_cosine_with_mean = source_cosines[..., chosen_token].clip(-1, 1)\n",
    "    acts_arccos_with_mean = np.arccos(acts_cosine_with_mean)\n",
    "\n",
    "    # +/- 1 std band across the batch (one band per category)\n",
    "    fig.add_trace(\n",
    "        variance_plot(\n",
    "            x=x_values,\n",
    "            y=torch.tensor(acts_arccos_with_mean).reshape(\n",
    "                -1, acts_arccos_with_mean.shape[-1]\n",
    "            ),\n",
    "            yaxis=\"y2\",\n",
    "            fillcolor=colour_map_opaque[category],\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=acts_arccos_with_mean.mean(axis=-1).flatten(),\n",
    "            mode=\"lines+markers\",\n",
    "            showlegend=False,\n",
    "            marker=dict(color=colour_map[category], size=3),\n",
    "            line_width=1,\n",
    "            yaxis=\"y2\",\n",
    "        ),\n",
    "    )\n",
    "\n",
    "\n",
    "# Mean angle to the class mean, per class: layers x resid_modules\n",
    "harmful_locality_scores = np.arccos(\n",
    "    harmful_acts_cosine_with_mean[..., chosen_token].clip(-1, 1)\n",
    ").mean(axis=-1)\n",
    "harmless_locality_scores = np.arccos(\n",
    "    harmless_acts_cosine_with_mean[..., chosen_token].clip(-1, 1)\n",
    ").mean(axis=-1)\n",
    "locality_scores = np.maximum(harmful_locality_scores, harmless_locality_scores)\n",
    "\n",
    "# (1) / max((2)): large when the class means are far apart relative to the\n",
    "# spread of the more diffuse class\n",
    "scores = sparsity_scores / locality_scores\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=x_values,\n",
    "        y=scores.flatten(),\n",
    "        name=\"score (1) / max((2))\",\n",
    "        mode=\"lines+markers\",\n",
    "        yaxis=\"y3\",\n",
    "    )\n",
    ")\n",
    "\n",
    "\n",
    "fig.update_layout(\n",
    "    title=(\n",
    "        \"Statistics of activations between harmful and harmless activations at each\"\n",
    "        f\" layer for {MODEL_PATH}\"\n",
    "    ),\n",
    "    plot_bgcolor=\"white\",\n",
    "    grid=dict(rows=3, columns=1),\n",
    "    xaxis=dict(\n",
    "        type=\"category\", title=\"Transformers Block\", dtick=2, gridcolor=\"lightgrey\"\n",
    "    ),\n",
    "    yaxis=dict(\n",
    "        title=\"Harmful-Harmless angle (1)\", gridcolor=\"lightgrey\", zeroline=False\n",
    "    ),\n",
    "    yaxis2=dict(title=\"Angle to class mean (2)\", gridcolor=\"lightgrey\", zeroline=False),\n",
    "    yaxis3=dict(title=\"(1) / max((2))\", gridcolor=\"lightgrey\", zeroline=False),\n",
    "    hovermode=\"x unified\",\n",
    "    height=1200,\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "# an adhoc attempt to find the best direction\n",
    "chosen_layer, chosen_act_idx = np.unravel_index(\n",
    "    np.nanargmax(scores, axis=None), scores.shape\n",
    ")\n",
    "print(\n",
    "    f\"Best direction at layer {chosen_layer}, module\"\n",
    "    f\" {act_names[chosen_act_idx]}, position {chosen_token}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# another adhoc attempt to find the best direction, this time searching over\n",
    "# the last `num_last_tokens` token positions instead of only `chosen_token`\n",
    "num_last_tokens = 2\n",
    "\n",
    "# layers x resid_modules x tokens: mean angle of samples to their class mean\n",
    "harmful_locality_scores = np.arccos(harmful_acts_cosine_with_mean.clip(-1, 1)).mean(\n",
    "    axis=-2\n",
    ")\n",
    "harmless_locality_scores = np.arccos(harmless_acts_cosine_with_mean.clip(-1, 1)).mean(\n",
    "    axis=-2\n",
    ")\n",
    "locality_scores = harmful_locality_scores + harmless_locality_scores\n",
    "\n",
    "sparsity_scores = np.arccos(similarity_scores.clip(-1, 1))\n",
    "\n",
    "scores = sparsity_scores / np.maximum(harmful_locality_scores, harmless_locality_scores)\n",
    "print(scores.shape)\n",
    "scores = scores[..., -num_last_tokens:]\n",
    "_chosen_layer, _chosen_act_idx, _chosen_token = np.unravel_index(\n",
    "    np.nanargmax(scores, axis=None), scores.shape\n",
    ")\n",
    "# `_chosen_token` indexes the sliced scores; report the matching negative\n",
    "# position in the full token axis (same convention as `chosen_token`).\n",
    "print(\n",
    "    f\"Highest score at layer {_chosen_layer}, module\"\n",
    "    f\" {act_names[_chosen_act_idx]}, position {_chosen_token - scores.shape[-1]}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Store the candidate directions from every extraction point\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def PID_control(candidates: torch.Tensor, kp, ki, kd):\n",
    "    \"\"\"\n",
    "    Smooth a sequence of candidate directions with a PID-style recurrence.\n",
    "\n",
    "    Row 0 is copied unchanged; for i >= 1:\n",
    "\n",
    "        out[i] = kp * c[i]\n",
    "               + ki * (out[0] + ... + out[i-1] + c[i])\n",
    "               + kd * (c[i] - out[i-1])\n",
    "\n",
    "    NOTE(review): the integral term accumulates previous *outputs* and the\n",
    "    derivative term compares the current input with the previous *output*.\n",
    "    A textbook PID controller uses the input/error sequence for both, so\n",
    "    confirm this formulation is intended.\n",
    "\n",
    "    Args:\n",
    "        candidates: Tensor whose first axis is the sequence the recurrence runs\n",
    "            over; trailing axes are arbitrary.\n",
    "        kp, ki, kd: Proportional, integral and derivative gains.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the same shape, dtype and device as `candidates`\n",
    "        (an empty tensor for an empty input).\n",
    "    \"\"\"\n",
    "    new_candidates = torch.zeros_like(candidates)\n",
    "    num_cands = candidates.shape[0]\n",
    "    if num_cands == 0:\n",
    "        return new_candidates\n",
    "    new_candidates[0] = candidates[0]\n",
    "    for i in range(1, num_cands):\n",
    "        # Re-summing the prefix keeps results bit-identical to the closed form above\n",
    "        integral = torch.sum(new_candidates[:i], dim=0) + candidates[i]\n",
    "        derivative = candidates[i] - new_candidates[i - 1]\n",
    "        new_candidates[i] = kp * candidates[i] + ki * integral + kd * derivative\n",
    "    return new_candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def RePE(candiates):\n",
    "    \"\"\"\n",
    "    First principal direction of each group of candidate vectors (PCA via SVD).\n",
    "\n",
    "    Args:\n",
    "        candiates: NumPy array shaped (..., n_samples, dim). For the 4-D input\n",
    "            this was written for, the samples axis is axis 2; any number of\n",
    "            leading batch axes is supported.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray shaped (..., dim): the top right-singular vector of each\n",
    "        centred group.\n",
    "\n",
    "    NOTE(review): singular vectors are only defined up to sign, so the returned\n",
    "    direction may point either way. Align it (e.g. with a mean difference)\n",
    "    before steering if the sign matters.\n",
    "    \"\"\"\n",
    "    # 1. Center each group along its samples axis\n",
    "    centred = candiates - np.mean(candiates, axis=-2, keepdims=True)\n",
    "\n",
    "    # 2. Batched SVD: np.linalg.svd broadcasts over the leading axes\n",
    "    _, _, vh = np.linalg.svd(centred, full_matrices=False)\n",
    "\n",
    "    # 3. Rows of vh are right-singular vectors sorted by singular value\n",
    "    return vh[..., 0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ITI(X, y, num_layers):\n",
    "    \"\"\"\n",
    "    Train one bias-free logistic-regression probe per layer; return their weights.\n",
    "\n",
    "    Args:\n",
    "        X: Tensor indexed first by layer; X[i] holds the (n_samples, dim)\n",
    "            features for layer i.\n",
    "        y: Binary labels (n_samples,) shared by every layer. Must be a floating\n",
    "            dtype, as nn.BCEWithLogitsLoss requires.\n",
    "        num_layers: Number of probes to train (uses X[0] .. X[num_layers - 1]).\n",
    "\n",
    "    Returns:\n",
    "        Tensor (num_layers, dim) of probe weight vectors, on the training device.\n",
    "    \"\"\"\n",
    "\n",
    "    # --- model factory ---\n",
    "    class LogisticLR(nn.Module):\n",
    "        def __init__(self, in_dim):\n",
    "            super().__init__()\n",
    "            # logistic regression = linear (no bias) + sigmoid inside the loss\n",
    "            self.linear = nn.Linear(in_dim, 1, bias=False)\n",
    "\n",
    "        def forward(self, x):\n",
    "            return self.linear(x).squeeze(-1)  # logits\n",
    "\n",
    "    def train_one(model, loader, epochs=5, lr=1e-2, device=\"cuda\" if torch.cuda.is_available() else \"cpu\"):\n",
    "        \"\"\"Fit `model` with AdamW + BCE-with-logits; return (model, last minibatch loss as a float).\"\"\"\n",
    "        model.to(device)\n",
    "        opt = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "        loss_fn = nn.BCEWithLogitsLoss()\n",
    "        model.train()\n",
    "        last_loss = None\n",
    "        for _ in range(epochs):\n",
    "            for xb, yb in loader:\n",
    "                xb, yb = xb.to(device), yb.to(device)\n",
    "                batch_loss = loss_fn(model(xb), yb)\n",
    "                opt.zero_grad()\n",
    "                batch_loss.backward()\n",
    "                opt.step()\n",
    "                last_loss = batch_loss.detach()\n",
    "        # NaN signals that the loader produced no batches\n",
    "        return model, (float(\"nan\") if last_loss is None else last_loss.item())\n",
    "\n",
    "    # --- train one probe per layer ---\n",
    "    probes = []\n",
    "    for i in range(num_layers):\n",
    "        print(f'==== Training Logistic Classifier for Layer {i} ====')\n",
    "        # Seed BEFORE building the probe so its initial weights (and the loader\n",
    "        # shuffle) are reproducible and differ per layer.\n",
    "        torch.manual_seed(1234 + i)\n",
    "        probe = LogisticLR(X.shape[-1])\n",
    "        loader = DataLoader(TensorDataset(X[i], y), batch_size=64, shuffle=True)\n",
    "        _, loss = train_one(probe, loader, epochs=20)\n",
    "        print(loss)\n",
    "        probes.append(probe)\n",
    "    return torch.stack([p.linear.weight.data.squeeze(0) for p in probes])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_direction_ablation_output_hook(layer, direction: Tensor):\n",
    "    def hook_fn(output, hook):\n",
    "        nonlocal direction\n",
    "\n",
    "        if isinstance(output, tuple):\n",
    "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output[0]\n",
    "        else:\n",
    "            activation: Float[Tensor, \"batch_size seq_len d_model\"] = output\n",
    "\n",
    "        direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)\n",
    "        direction = direction.to(activation)\n",
    "        print(direction.shape)\n",
    "        print(activation.shape)\n",
    "        # activation -= 2*(activation @ direction).unsqueeze(-1) * direction\n",
    "        activation -= direction\n",
    "\n",
    "        if isinstance(output, tuple):\n",
    "            return (activation, *output[1:])\n",
    "        else:\n",
    "            return activation\n",
    "\n",
    "    return hook_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Refusal direction analysis\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pairwise cosine of refusal directions at each layer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Gram matrix of all refusal-direction candidates, one per extraction point\n",
    "# (layer x {mid, post}). layer_names assumes refusal_dirs is laid out as\n",
    "# (num_layers, 2, dim) -- TODO confirm. The plot is titled cosine similarity,\n",
    "# which holds only if every candidate is unit-norm.\n",
    "layer_names = sum([[f\"{i}-mid\", f\"{i}-post\"] for i in range(num_layers)], [])\n",
    "\n",
    "dirs = refusal_dirs.reshape(-1, refusal_dirs.shape[-1])\n",
    "A = dirs @ dirs.T\n",
    "\n",
    "fig = px.imshow(\n",
    "    # np.rad2deg(np.arccos(np.clip(A, 0.0, 1.0))),\n",
    "    A,\n",
    "    x=layer_names,\n",
    "    y=layer_names,\n",
    "    width=len(layer_names) * 14,\n",
    "    height=len(layer_names) * 14,\n",
    "    title=\"Cosine Similarity Matrix\",\n",
    "    color_continuous_scale=\"Viridis\",\n",
    ")\n",
    "# one tick per extraction point on both axes\n",
    "fig.update_layout(\n",
    "    yaxis=dict(dtick=1),\n",
    "    xaxis=dict(dtick=1),\n",
    ")\n",
    "# fig.update_traces(xgap=1, ygap=1)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Norm of each refusal direction and its mean cosine similarity with the directions at all other extraction points\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.functional import normalize\n",
    "\n",
    "# One category label per extraction point: two per layer (mid and post),\n",
    "# numbered 0 .. 2 * num_layers - 1.\n",
    "layer_names = [str(i) for i in range(2 * num_layers)]\n",
    "\n",
    "# Unit-normalise the mean harmful / harmless activations at the chosen token and\n",
    "# take their difference: one raw (un-normalised) refusal direction per point.\n",
    "harmful_acts_normed_mean_normed = normalize(\n",
    "    harmful_acts_normed_mean[:, :, chosen_token], dim=-1\n",
    ")\n",
    "harmless_acts_normed_mean_normed = normalize(\n",
    "    harmless_acts_normed_mean[:, :, chosen_token], dim=-1\n",
    ")\n",
    "raw_dirs = harmful_acts_normed_mean_normed - harmless_acts_normed_mean_normed\n",
    "\n",
    "# flatten (layer, module) -> extraction point\n",
    "raw_dirs = raw_dirs.reshape((-1, raw_dirs.shape[-1]))\n",
    "\n",
    "# Norm of the raw refusal direction at every extraction point: a light\n",
    "# line+markers trace underneath, darker markers drawn on top.\n",
    "fig = go.Figure()\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=layer_names,\n",
    "        y=raw_dirs.norm(dim=-1),\n",
    "        mode=\"lines+markers\",\n",
    "        yaxis=\"y\",\n",
    "        marker_color=colour_map_light[\"neutral\"],\n",
    "        marker_size=8,\n",
    "        showlegend=False,\n",
    "    )\n",
    ")\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=layer_names[::],\n",
    "        y=raw_dirs.norm(dim=-1)[::],\n",
    "        mode=\"markers\",\n",
    "        yaxis=\"y\",\n",
    "        marker_color=colour_map[\"neutral\"],\n",
    "        marker_size=8,\n",
    "        showlegend=False,\n",
    "    )\n",
    ")\n",
    "\n",
    "# Extraction point with the largest norm, excluding the last point (np.argmax, so\n",
    "# unlike the nanargmax used in the selection cell this does not skip NaNs).\n",
    "print(layer_names[np.argmax(raw_dirs.norm(dim=-1)[:-1])])\n",
    "\n",
    "\n",
    "fig.update_layout(\n",
    "    # title=(\n",
    "    #     \"Statistics of refusal direction candidates at each layer\"\n",
    "    #     f\" layer for {MODEL_PATH}\"\n",
    "    # ),\n",
    "    plot_bgcolor=\"white\",\n",
    "    grid=dict(rows=1, columns=1),\n",
    "    xaxis=dict(\n",
    "        type=\"category\",\n",
    "        title=dict(text=\"Extraction Point\", font=dict(size=28)),\n",
    "        dtick=4,\n",
    "        gridcolor=\"lightgrey\",\n",
    "        tickfont=dict(size=24),\n",
    "    ),\n",
    "    yaxis=dict(\n",
    "        title=dict(text=\"Norm of<br>Refusal Direction\", font=dict(size=28)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        zeroline=False,\n",
    "        tickfont=dict(size=24),\n",
    "    ),\n",
    "    hovermode=\"x unified\",\n",
    "    height=300,\n",
    "    width=1000,\n",
    "    # width=20 + 12 * len(x_values),\n",
    "    margin=dict(l=20, r=20, t=20, b=20),\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "# NOTE(review): plotly's write_image needs the kaleido package, which the pip cell\n",
    "# at the top does not install -- confirm it is available.\n",
    "fig.write_image(VISUALIZATION_DIR / \"norm_refusal.pdf\", scale=5)\n",
    "\n",
    "\n",
    "# Mean cosine similarity of each candidate with every candidate (itself included).\n",
    "# The Gram matrix equals cosine similarity only if the rows of refusal_dirs are\n",
    "# unit-norm -- TODO confirm. np.nanmean ignores NaN entries.\n",
    "flatten_dirs = refusal_dirs.reshape(-1, refusal_dirs.shape[-1])\n",
    "pairwise_cosine = flatten_dirs @ flatten_dirs.T\n",
    "mean_cosine = np.nanmean(pairwise_cosine, axis=-1)\n",
    "\n",
    "fig = go.Figure()\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=layer_names,\n",
    "        y=mean_cosine,\n",
    "        mode=\"lines+markers\",\n",
    "        yaxis=\"y\",\n",
    "        marker_color=colour_map_light[\"neutral\"],\n",
    "        showlegend=False,\n",
    "        marker_size=8,\n",
    "    )\n",
    ")\n",
    "# same points again in the darker colour, drawn on top of the light line\n",
    "fig.add_trace(\n",
    "    go.Scatter(\n",
    "        x=layer_names,\n",
    "        y=mean_cosine,\n",
    "        mode=\"markers\",\n",
    "        yaxis=\"y\",\n",
    "        marker_color=colour_map[\"neutral\"],\n",
    "        showlegend=False,\n",
    "        marker_size=8,\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.update_layout(\n",
    "    plot_bgcolor=\"white\",\n",
    "    grid=dict(rows=1, columns=1),\n",
    "    xaxis=dict(\n",
    "        type=\"category\",\n",
    "        title=dict(text=\"Extraction Point\", font=dict(size=28)),\n",
    "        dtick=4,\n",
    "        gridcolor=\"lightgrey\",\n",
    "        tickfont=dict(size=24),\n",
    "    ),\n",
    "    yaxis=dict(\n",
    "        title=dict(text=\"Mean<br>Cosine Score\", font=dict(size=28)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        zeroline=False,\n",
    "        tickfont=dict(size=24),\n",
    "    ),\n",
    "    hovermode=\"x unified\",\n",
    "    height=300,\n",
    "    width=1000,\n",
    "    margin=dict(l=20, r=20, t=20, b=20),\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "# A bare expression in the middle of a cell is never displayed, so print the\n",
    "# extraction point with the highest mean cosine explicitly.\n",
    "print(layer_names[np.nanargmax(mean_cosine)])\n",
    "\n",
    "fig.write_image(VISUALIZATION_DIR / \"mean_cosine.pdf\", scale=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: the batched einsum projection must match a plain matrix product\n",
    "\n",
    "# presumably layers x resid_modules x batch x tokens x dim: axis 3 is the one\n",
    "# indexed with chosen_token below -- TODO confirm against the extraction code\n",
    "category2acts_normed = {\n",
    "    \"harmful\": harmful_acts_normed,\n",
    "    \"harmless\": harmless_acts_normed,\n",
    "}\n",
    "\n",
    "print(refusal_dirs.shape)\n",
    "x = category2acts_normed[\"harmful\"][..., chosen_token, :]\n",
    "print(x.shape)\n",
    "\n",
    "a = einops.einsum(\n",
    "    refusal_dirs, x, \"layer act dim, layer act batch dim -> layer act batch\"\n",
    ")\n",
    "# Actually check the result (float32-appropriate tolerance; NaNs compare equal).\n",
    "np.testing.assert_allclose(a[0][0], refusal_dirs[0][0] @ x[0][0].T, rtol=1e-5, atol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scalar projections of activations onto the local refusal direction at each extraction point\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# presumably layers x resid_modules x batch x tokens x dim: axis 3 is the one\n",
    "# indexed with chosen_token below -- TODO confirm against the extraction code\n",
    "category2acts_normed = {\n",
    "    \"harmful\": harmful_acts_normed,\n",
    "    \"harmless\": harmless_acts_normed,\n",
    "}\n",
    "\n",
    "# x_values = sum([[f\"{l}\", f\"{l}-post\"] for l in range(num_layers)], [])\n",
    "x_values = [str(i) for i in range(2 * num_layers)]\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "# For each category, project every activation onto the refusal direction of its\n",
    "# own extraction point, then plot the mean with a variance band.\n",
    "for category in categories:\n",
    "    acts_normed = category2acts_normed[category][:, :, :, chosen_token]\n",
    "    projections = einops.einsum(\n",
    "        refusal_dirs,\n",
    "        acts_normed,\n",
    "        \"layer act dim, layer act batch dim -> layer act batch\",\n",
    "    )\n",
    "    projections = torch.tensor(projections)\n",
    "\n",
    "    mean_projection = projections.mean(dim=-1)\n",
    "\n",
    "    y_values = mean_projection.flatten()\n",
    "\n",
    "    # mean\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=y_values,\n",
    "            name=category,\n",
    "            mode=\"lines+markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map[category], size=3),\n",
    "            showlegend=True,\n",
    "        )\n",
    "    )\n",
    "    # same mean again in the light colour, without a legend entry\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=y_values,\n",
    "            name=category,\n",
    "            mode=\"lines+markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map_light[category], size=3),\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # variance\n",
    "    fig.add_trace(\n",
    "        variance_plot(\n",
    "            x=x_values,\n",
    "            y=projections.reshape(-1, projections.shape[-1]),\n",
    "            yaxis=\"y\",\n",
    "            fillcolor=colour_map_opaque[category],\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # dot markers\n",
    "    # NOTE(review): [1::] omits the marker for the first extraction point, while the\n",
    "    # other plots in this notebook use the full range -- confirm this is intended.\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values[1::],\n",
    "            y=y_values[1::],\n",
    "            name=f\"{category}\",\n",
    "            mode=\"markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map[category], size=3),\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "fig.update_layout(\n",
    "    # title=f\"Scalar projections of activations onto the local refusal direction at each\n",
    "    # extraction point for {MODEL_PATH}\",\n",
    "    plot_bgcolor=\"white\",\n",
    "    grid=dict(rows=1, columns=1),\n",
    "    xaxis=dict(\n",
    "        type=\"category\",\n",
    "        dtick=4,\n",
    "        title=dict(text=\"Extraction Point\", font=dict(size=20)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        tickfont=dict(size=18),\n",
    "    ),\n",
    "    yaxis=dict(\n",
    "        title=dict(text=\"Scalar Projections\", font=dict(size=20)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        zeroline=False,\n",
    "        tickfont=dict(size=18),\n",
    "    ),\n",
    "    hovermode=\"x unified\",\n",
    "    height=250,\n",
    "    # width=20 + 12 * len(x_values),\n",
    "    width=600,\n",
    "    margin=dict(l=0, r=0, t=0, b=0),\n",
    "    legend=dict(x=0.05, y=0.95, font=dict(size=18)),\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "fig.write_image(VISUALIZATION_DIR / \"prj_onto_local_refusal_candidates.pdf\", scale=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Criteria for selecting the refusal direction\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Criterion 1: largest raw refusal-direction norm (last extraction point excluded).\n",
    "criteria = raw_dirs.norm(dim=-1)[:-1]\n",
    "argmax = np.nanargmax(criteria)\n",
    "# flat index -> (layer, module): two extraction points per layer\n",
    "max_norm_layer, max_norm_act_idx = divmod(argmax, 2)\n",
    "print(\n",
    "    f\"Highest refusal direction norm at layer {max_norm_layer}, module\"\n",
    "    f\" {act_names[max_norm_act_idx]}, position {chosen_token}\"\n",
    ")\n",
    "\n",
    "# Criterion 2: highest mean cosine similarity with all candidates.\n",
    "argmax = np.nanargmax(mean_cosine)\n",
    "max_mean_cosine_layer, max_mean_cosine_act_idx = divmod(argmax, 2)\n",
    "print(\n",
    "    f\"Highest cosine similarity at layer {max_mean_cosine_layer}, module\"\n",
    "    f\" {act_names[max_mean_cosine_act_idx]}, position {chosen_token}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Selecting the refusal direction\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the extraction point with the highest mean cosine similarity.\n",
    "chosen_layer = max_mean_cosine_layer\n",
    "chosen_act_idx = max_mean_cosine_act_idx\n",
    "# NOTE(review): chosen_token is (re)assigned here although cells above already read\n",
    "# it, so their results depend on run order if it previously held another value.\n",
    "chosen_token = -1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Projection of the activations at each extraction point onto the chosen refusal direction\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = go.Figure()\n",
    "\n",
    "for category in [\"harmful\", \"harmless\"]:\n",
    "    if category == \"harmful\":\n",
    "        acts_normed = harmful_acts_normed\n",
    "    else:\n",
    "        acts_normed = harmless_acts_normed\n",
    "\n",
    "    # layers x resid_modules x batch_size x dim\n",
    "    activations = acts_normed[..., chosen_token, :].copy()\n",
    "\n",
    "    # dim\n",
    "    direction = refusal_dirs[chosen_layer, chosen_act_idx].copy()\n",
    "\n",
    "    # layers x resid_modules x batch_size\n",
    "    scalar_projections = einops.einsum(\n",
    "        activations,\n",
    "        direction,\n",
    "        \"... batch_size dim, ... dim -> ... batch_size\",\n",
    "    )\n",
    "    scalar_projections = np.nan_to_num(scalar_projections)\n",
    "    print(category)\n",
    "    print(scalar_projections.mean())\n",
    "\n",
    "    y_values = scalar_projections\n",
    "    batch_size = scalar_projections.shape[-1]\n",
    "\n",
    "    # one category label per extraction point (two per layer)\n",
    "    x_values = [str(i) for i in range(2 * num_layers)]\n",
    "\n",
    "    # variance\n",
    "    fig.add_trace(\n",
    "        variance_plot(\n",
    "            x=x_values,\n",
    "            y=torch.tensor(y_values).reshape(-1, batch_size),\n",
    "            yaxis=\"y\",\n",
    "            fillcolor=colour_map_opaque[category],\n",
    "            showlegend=False,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # mean\n",
    "    ## for legend\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=y_values.mean(axis=-1).flatten(),\n",
    "            mode=\"lines+markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map[category], size=3),\n",
    "            showlegend=True,\n",
    "            name=category,\n",
    "        )\n",
    "    )\n",
    "    ## for lines\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=y_values.mean(axis=-1).flatten(),\n",
    "            mode=\"lines\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map_light[category], size=3),\n",
    "            showlegend=False,\n",
    "            name=category,\n",
    "        )\n",
    "    )\n",
    "    ## for markers\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=x_values,\n",
    "            y=y_values.mean(axis=-1).flatten(),\n",
    "            mode=\"markers\",\n",
    "            yaxis=\"y\",\n",
    "            marker=dict(color=colour_map[category], size=3),\n",
    "            showlegend=False,\n",
    "            name=category,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # Diagnostic only (not plotted): flip activations with a positive projection to\n",
    "    # the other side (a true reflection if direction is unit-norm) and print the new\n",
    "    # mean projection. Works on the local copy, so the source arrays are untouched.\n",
    "    activations -= 2 * einops.einsum(\n",
    "        np.maximum(scalar_projections, 0),\n",
    "        direction,\n",
    "        \"layer resid_module batch_size, dim -> layer resid_module batch_size dim\",\n",
    "    )\n",
    "    scalar_projections = einops.einsum(\n",
    "        activations,\n",
    "        direction,\n",
    "        \"... batch_size dim, ... dim -> ... batch_size\",\n",
    "    )\n",
    "    print(category)\n",
    "    print(scalar_projections.mean())\n",
    "\n",
    "\n",
    "# module_names is also read by the weight-projection cell further down.\n",
    "module_names = [\"mid\", \"post\"]\n",
    "fig.update_layout(\n",
    "    grid=dict(rows=1, columns=1),\n",
    "    # yaxis=dict(tickformat=\".2E\"),\n",
    "    plot_bgcolor=\"white\",\n",
    "    xaxis=dict(\n",
    "        type=\"category\",\n",
    "        dtick=4,\n",
    "        title=dict(text=\"Extraction Point\", font=dict(size=20)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        tickfont=dict(size=18),\n",
    "    ),\n",
    "    yaxis=dict(\n",
    "        title=dict(text=\"Scalar Projections\", font=dict(size=20)),\n",
    "        gridcolor=\"lightgrey\",\n",
    "        zeroline=False,\n",
    "        tickfont=dict(size=18),\n",
    "    ),\n",
    "    hovermode=\"x unified\",\n",
    "    height=250,\n",
    "    width=600,\n",
    "    # title=(\n",
    "    #     \"Scalar projections of activations at each layer onto the chosen refusal direction\"\n",
    "    #     f\" ({chosen_layer}-{module_names[chosen_act_idx]})\"\n",
    "    # ),\n",
    "    # yaxis=dict(matches=None),\n",
    "    margin=dict(l=20, r=20, t=20, b=20),\n",
    "    legend=dict(x=0.05, y=0.95, font=dict(size=18)),\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "fig.write_image(VISUALIZATION_DIR / \"prj_onto_refusal_dir.pdf\", scale=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check\n",
    "# The precomputed normalised activation of one sample must match normalising the\n",
    "# raw activation by hand.\n",
    "\n",
    "print(harmful_acts.shape)\n",
    "a = harmful_acts[chosen_layer, chosen_act_idx, 0, chosen_token]\n",
    "# NOTE(review): this normalises in harmful_acts' own dtype before the float32 cast;\n",
    "# if that dtype is low precision the comparison below may exceed its tolerance.\n",
    "an = a / a.norm()\n",
    "an = an.cpu().float()\n",
    "b = harmful_acts_normed[chosen_layer, chosen_act_idx, 0, chosen_token].copy()\n",
    "\n",
    "print(an.dtype)\n",
    "print(b.dtype)\n",
    "\n",
    "print(an, np.linalg.norm(an))\n",
    "print(b, np.linalg.norm(b))\n",
    "# note: rtol=10e-6 equals 1e-5\n",
    "np.testing.assert_allclose(an, b, rtol=10e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scalar projections of weights at each SelfAttn and MLP layer onto the refusal direction\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project the output weights of every attention block (attn.W_O) and MLP block\n",
    "# (mlp.W_out) onto the unit-normalised chosen refusal direction; one box per module.\n",
    "# NOTE(review): module_names (used in the title) is defined in an earlier cell.\n",
    "layer_names = sum([[f\"{i}-mid\", f\"{i}-post\"] for i in range(num_layers)], [])\n",
    "direction = refusal_dirs[chosen_layer, chosen_act_idx].copy()\n",
    "direction /= np.linalg.norm(direction)\n",
    "\n",
    "x_values = []\n",
    "prj_values = []\n",
    "sum_magnitude = []\n",
    "\n",
    "for i in range(num_layers):\n",
    "    W = model.blocks[i].attn.W_O\n",
    "    prjs = W.detach().cpu().float().numpy() @ direction\n",
    "    prjs = prjs.flatten()\n",
    "    prj_values.append(prjs)\n",
    "    x_values.extend([layer_names[i * 2]] * prjs.shape[0])\n",
    "    sum_magnitude.append(np.sum(np.abs(prjs)))\n",
    "\n",
    "    W = model.blocks[i].mlp.W_out\n",
    "    prjs = W.detach().cpu().float().numpy() @ direction\n",
    "    prjs = prjs.flatten()\n",
    "    prj_values.append(prjs)\n",
    "    x_values.extend([layer_names[i * 2 + 1]] * prjs.shape[0])\n",
    "    sum_magnitude.append(np.sum(np.abs(prjs)))\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "fig.add_trace(\n",
    "    go.Box(\n",
    "        x=x_values,\n",
    "        y=np.hstack(prj_values),\n",
    "        boxmean=True,\n",
    "        marker_color=px.colors.qualitative.Plotly[3],\n",
    "    )\n",
    ")\n",
    "fig.update_layout(\n",
    "    title=(\n",
    "        \"Scalar projections of weights at each layer onto the refusal direction\"\n",
    "        f\" ({chosen_layer}-{module_names[chosen_act_idx]})\"\n",
    "    ),\n",
    "    yaxis_range=[-0.5, 0.5],\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "# total absolute projection per module\n",
    "fig = px.line(x=layer_names, y=sum_magnitude, markers=True)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2EoxY5i1CWe3"
   },
   "source": [
    "## Constructing the steering plane\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute the PCA from candidate directions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# One row per extraction point; PCA over all candidate directions.\n",
    "refusal_dirs_flatten = refusal_dirs.reshape(-1, refusal_dirs.shape[-1])\n",
    "# print explicitly: a bare expression mid-cell is never displayed\n",
    "print(refusal_dirs_flatten.shape)\n",
    "\n",
    "pca_model = PCA().fit(refusal_dirs_flatten)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of variance explained by each principal component.\n",
    "# NOTE(review): this shadows Python's builtin vars(); a name such as\n",
    "# explained_var_ratio would be safer if no later cell relies on `vars`.\n",
    "vars = pca_model.explained_variance_ratio_\n",
    "vars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "components = pca_model.components_\n",
    "\n",
    "print(refusal_dirs_flatten.shape)\n",
    "print(components.shape)\n",
    "# Angle (degrees) between each candidate and the last principal component. Dot\n",
    "# products are clipped so round-off just outside [-1, 1] cannot give NaN, and the\n",
    "# result is printed because a bare expression mid-cell is never displayed.\n",
    "print(np.degrees(np.arccos(np.clip(refusal_dirs_flatten @ components[-1], -1.0, 1.0))))\n",
    "\n",
    "# another adhoc attempt to find the best direction is to take the mean of the candidates\n",
    "mean_d = refusal_dirs_flatten.mean(axis=0)\n",
    "mean_d /= np.linalg.norm(mean_d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Angle (degrees) between each principal component and the chosen direction; the\n",
    "# dot products are clipped so round-off just outside [-1, 1] cannot give NaN.\n",
    "print(chosen_layer, chosen_act_idx)\n",
    "np.degrees(np.arccos(np.clip(components @ refusal_dirs[chosen_layer][chosen_act_idx], -1.0, 1.0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization of the candidate directions on the steering plane\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# First basis: the chosen refusal direction (set in the selection cell; currently the\n",
    "# one with the highest mean cosine similarity).\n",
    "print(chosen_layer, chosen_act_idx)\n",
    "u1 = refusal_dirs[chosen_layer][chosen_act_idx].copy()\n",
    "\n",
    "# second basis is the first principal component\n",
    "u2 = components[0].copy()\n",
    "\n",
    "# Gram-Schmidt: orthonormal basis (b1, b2) of the steering plane\n",
    "b1 = u1 / np.linalg.norm(u1)\n",
    "b2 = u2 - (u2 @ b1) * b1\n",
    "b2 /= np.linalg.norm(b2)\n",
    "\n",
    "# 2-D coordinates of every candidate on the plane\n",
    "prj_matrix = np.column_stack([b1, b2])\n",
    "refusal_dirs_mapped = refusal_dirs_flatten @ prj_matrix\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "# Per-candidate norms (axis=-1): a single whole-matrix norm would let one NaN\n",
    "# candidate turn every x / y into NaN.\n",
    "norms = np.linalg.norm(refusal_dirs_mapped, axis=-1)\n",
    "x = refusal_dirs_mapped[:, 0] / norms\n",
    "y = refusal_dirs_mapped[:, 1] / norms\n",
    "# arctan2 is scale-invariant, so the angle comes straight from the raw coordinates\n",
    "angle = np.arctan2(refusal_dirs_mapped[:, 1], refusal_dirs_mapped[:, 0])\n",
    "\n",
    "# Arrows from the origin for the chosen direction and the 1st PC, each with a label.\n",
    "for point, label in zip(\n",
    "    [u1 @ prj_matrix, u2 @ prj_matrix], [\"chosen<br>direction\", \"1st PC\"]\n",
    "):\n",
    "    fig.add_annotation(\n",
    "        # hovertext=str(i),\n",
    "        ax=0,\n",
    "        ay=0,\n",
    "        x=point[0],\n",
    "        y=point[1],\n",
    "        axref=\"x\",\n",
    "        ayref=\"y\",\n",
    "        showarrow=True,\n",
    "        arrowhead=2,\n",
    "        arrowwidth=2,\n",
    "        xanchor=\"right\",\n",
    "        yanchor=\"top\",\n",
    "        opacity=0.5,\n",
    "    )\n",
    "    fig.add_annotation(\n",
    "        x=point[0],\n",
    "        y=point[1],\n",
    "        text=label,\n",
    "        font=dict(size=22),\n",
    "        showarrow=False,\n",
    "        yshift=30,\n",
    "        xshift=20,\n",
    "    )\n",
    "\n",
    "# One arrow-shaped marker per candidate, rotated by its angle on the plane and\n",
    "# coloured by extraction-point index.\n",
    "points = go.Scatter(\n",
    "    x=refusal_dirs_mapped[:, 0],\n",
    "    y=refusal_dirs_mapped[:, 1],\n",
    "    text=[str(i) for i in range(len(refusal_dirs_mapped))],\n",
    "    mode=\"markers\",\n",
    "    marker=dict(\n",
    "        symbol=\"arrow\",\n",
    "        angle=90 - np.degrees(angle),\n",
    "        # size=[i + 7 if not np.isnan(i) else 0 for i in norms.flatten()],\n",
    "        size=20,\n",
    "        # color=np.linspace(0, refusal_dirs_mapped.shape[0], refusal_dirs_mapped.shape[0]),\n",
    "        color=[i for i in range(refusal_dirs_mapped.shape[0])],\n",
    "        # color=color.flatten(),\n",
    "        # colorscale=\"Phase\",\n",
    "        showscale=True,\n",
    "    ),\n",
    "    name=\"layers\",\n",
    "    showlegend=True,\n",
    ")\n",
    "fig.add_trace(points)\n",
    "\n",
    "# caption for the colour bar\n",
    "fig.add_annotation(\n",
    "    xref=\"paper\",\n",
    "    yref=\"paper\",\n",
    "    text=\"Extraction<br>Point\",\n",
    "    font=dict(size=22),\n",
    "    showarrow=False,\n",
    "    x=1.17,\n",
    "    y=-0.15,\n",
    "    # yshift=20,\n",
    "    # xshift=20\n",
    ")\n",
    "\n",
    "\n",
    "# equal axis scaling so angles on screen match angles on the plane\n",
    "fig.update_layout(\n",
    "    # plot_bgcolor=\"white\",\n",
    "    autosize=False,\n",
    "    # width=800,\n",
    "    height=600,\n",
    "    # yaxis_range=[-1.0, 1.0],\n",
    "    # xaxis_range=[-0.4, 1.4],\n",
    "    yaxis_scaleanchor=\"x\",\n",
    "    yaxis_scaleratio=1,\n",
    "    xaxis_dtick=0.5,\n",
    "    yaxis_dtick=0.5,\n",
    "    font=dict(size=22),\n",
    "    margin=dict(l=0, r=100, t=0, b=75),\n",
    "    legend=dict(visible=False),\n",
    "    # xaxis=dict(gridcolor=\"grey\"),\n",
    "    # yaxis=dict(gridcolor=\"grey\"),\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "fig.write_image(VISUALIZATION_DIR / \"steering_plane.pdf\", scale=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Steering by rotation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rotation utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_rotation_matrix(degree, basis1, basis2):\n",
    "    assert len(basis1.shape) == 1\n",
    "    assert len(basis2.shape) == 1\n",
    "    assert basis1.shape == basis2.shape\n",
    "\n",
    "    n = basis1.shape[-1]\n",
    "\n",
    "    if degree % 360 == 0:\n",
    "        return np.eye(n)\n",
    "\n",
    "    # ensure bases are orthonormal\n",
    "    u = basis1 / np.linalg.norm(basis1)\n",
    "    v = basis2 - (basis2 @ u) * u\n",
    "    v /= np.linalg.norm(v)\n",
    "\n",
    "    theta = np.deg2rad(degree)\n",
    "    cos_theta = np.cos(theta)\n",
    "    sin_theta = np.sin(theta)\n",
    "    # print(cos_theta, sin_theta)\n",
    "\n",
    "    # rotate counter-clockwise\n",
    "    R_theta = [[cos_theta, -sin_theta], [sin_theta, cos_theta]]\n",
    "\n",
    "    uv = np.column_stack([u, v])\n",
    "    R = np.eye(n) - (np.outer(u, u) + np.outer(v, v)) + uv @ R_theta @ uv.T\n",
    "\n",
    "    return R\n",
    "\n",
    "\n",
    "# sanity check\n",
    "# Rotate the chosen direction d by 30 degrees in the plane spanned by the last PC and\n",
    "# the mean candidate direction. The first print is the full-space angle between R @ d\n",
    "# and d (meaningful if d is unit-norm); the second is the angle between their\n",
    "# projections onto the plane, which should be 30 when d has an in-plane component.\n",
    "print(chosen_layer, chosen_act_idx)\n",
    "d = refusal_dirs[chosen_layer][chosen_act_idx].copy()\n",
    "b1 = components[-1].copy()\n",
    "b2 = mean_d.copy()\n",
    "\n",
    "b1 = b1 / np.linalg.norm(b1)\n",
    "b2 = b2 - (b2 @ b1) * b1\n",
    "b2 /= np.linalg.norm(b2)\n",
    "P = np.outer(b1, b1) + np.outer(b2, b2)\n",
    "\n",
    "# b1 and b2 are orthonormal by construction, so this should print 90\n",
    "deg = np.rad2deg(np.arccos(b1 @ b2))\n",
    "print(deg)\n",
    "\n",
    "R = get_rotation_matrix(30, b1, b2)\n",
    "\n",
    "u = P @ R @ d\n",
    "u /= np.linalg.norm(u)\n",
    "v = P @ d\n",
    "v /= np.linalg.norm(v)\n",
    "\n",
    "print(np.rad2deg(np.arccos((R @ d) @ d)))\n",
    "print(np.rad2deg(np.arccos(u @ v)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rotate_to_target(x, target_degree, basis1, basis2):\n",
    "    assert len(basis1.shape) == 1\n",
    "    assert len(basis2.shape) == 1\n",
    "    assert basis1.shape == basis2.shape\n",
    "\n",
    "    n = basis1.shape[-1]\n",
    "\n",
    "    # ensure bases are orthonormal\n",
    "    u = basis1 / np.linalg.norm(basis1)\n",
    "    v = basis2 - (basis2 @ u) * u\n",
    "    v /= np.linalg.norm(v)\n",
    "\n",
    "    theta = np.deg2rad(target_degree)\n",
    "    cos_theta = np.cos(theta)\n",
    "    sin_theta = np.sin(theta)\n",
    "\n",
    "    P = np.outer(u, u) + np.outer(v, v)\n",
    "\n",
    "    # rotate counter-clockwise\n",
    "    R_theta = [[cos_theta, -sin_theta], [sin_theta, cos_theta]]\n",
    "\n",
    "    uv = np.column_stack([u, v])\n",
    "\n",
    "    rotated_component = uv @ R_theta @ np.array([1, 0])\n",
    "    Px = x @ P\n",
    "    scale = np.linalg.norm(Px, axis=-1, keepdims=True)\n",
    "\n",
    "    result = x - Px + scale * rotated_component\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "# sanity check\n",
    "d = refusal_dirs[chosen_layer][chosen_act_idx].copy()\n",
    "b1 = components[-1].copy()\n",
    "b2 = mean_d.copy()\n",
    "b1 = b1 / np.linalg.norm(b1)\n",
    "b2 = b2 - (b2 @ b1) * b1\n",
    "b2 /= np.linalg.norm(b2)\n",
    "P = np.outer(b1, b1) + np.outer(b2, b2)\n",
    "\n",
    "rd = rotate_to_target(d, 60, b1, b2)\n",
    "\n",
    "# The in-plane part of rd must sit at exactly 60 degrees from b1.\n",
    "rd_plane = P @ rd\n",
    "rd_plane /= np.linalg.norm(rd_plane)\n",
    "assert np.isclose(rd_plane @ b1, np.cos(np.deg2rad(60)), atol=1e-4)\n",
    "\n",
    "# Compare with an explicit rotation by (60 - current angle). These only agree when\n",
    "# the in-plane part of d is on the b2 >= 0 side, since arccos loses the sign.\n",
    "u = P @ d\n",
    "u /= np.linalg.norm(u)\n",
    "deg = np.rad2deg(np.arccos(u @ b1))\n",
    "R = get_rotation_matrix(60 - deg, b1, b2)\n",
    "print(R @ d)\n",
    "print(rd)\n",
    "\n",
    "# batched inputs keep their shape\n",
    "assert rotate_to_target(np.random.rand(5, 8, d.shape[0]), 60, b1, b2).shape == (5, 8, d.shape[0])\n",
    "print(chosen_layer, chosen_act_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Statistics of candidate directions on the steering plane\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# steering plane spanned by the mean direction `mean_d` and `components[0]`\n",
    "b1 = mean_d.copy()\n",
    "b2 = components[0].copy()\n",
    "\n",
    "b1 = b1 / np.linalg.norm(b1)\n",
    "b2 = b2 - (b2 @ b1) * b1\n",
    "b2 /= np.linalg.norm(b2)\n",
    "\n",
    "P = np.outer(b1, b1) + np.outer(b2, b2)\n",
    "\n",
    "# b1 is parallel to mean_d, so projecting mean_d onto the plane must return it\n",
    "assert np.allclose(mean_d @ P, mean_d, rtol=1e-3, atol=1e-4)\n",
    "\n",
    "proj = refusal_dirs_flatten @ P\n",
    "proj_norm = np.linalg.norm(proj, axis=-1)\n",
    "\n",
    "fig = px.line(\n",
    "    x=layer_names, y=proj_norm, labels={\"x\": \"candidate direction\", \"y\": \"norm\"}\n",
    ")\n",
    "fig.update_layout(\n",
    "    title=\"Norm of projections of candidate directions on the steering plane\",\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "proj_normed = proj / proj_norm[:, None]\n",
    "# clip: rounding can push the cosine slightly outside [-1, 1], where arccos is NaN\n",
    "proj_cos = np.clip(proj_normed @ b1, -1.0, 1.0)\n",
    "proj_angle = np.rad2deg(np.arccos(proj_cos))\n",
    "\n",
    "# angles are measured against b1, i.e. the (normalized) mean direction\n",
    "fig = px.line(\n",
    "    x=layer_names,\n",
    "    y=proj_angle,\n",
    "    labels={\"x\": \"candidate direction\", \"y\": \"angle (degrees)\"},\n",
    ")\n",
    "fig.update_layout(\n",
    "    title=(\n",
    "        \"Angles between the projections of candidate directions and the mean\"\n",
    "        \" direction on the steering plane\"\n",
    "    ),\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the steering config\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "from accelerate import init_empty_weights\n",
    "from transformers import AutoConfig, AutoModel\n",
    "\n",
    "# Only the module *names* are needed (next cell), so instantiate the architecture\n",
    "# from its config with empty (meta) weights. `from_pretrained` would read every\n",
    "# checkpoint shard just to discard the weights.\n",
    "hf_config = AutoConfig.from_pretrained(MODEL_PATH)\n",
    "with init_empty_weights():\n",
    "    hf_model = AutoModel.from_config(hf_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# List every module name of the model to decide which modules to steer. The\n",
    "# interesting ones are the normalization layers in front of each attention and\n",
    "# MLP block.\n",
    "# NOTE(review): `hf_model` is built with `AutoModel`, so these names may lack the\n",
    "# `model.` prefix used when writing the steering config below; confirm.\n",
    "pprint([name for name, _module in hf_model.named_modules()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from pprint import pprint\n",
    "\n",
    "target_modules = [\"mid\", \"post\"]\n",
    "# because in the model's architecture, there are no explicit modules for the residual stream (pre/mid/post)\n",
    "# thus let's use the inputs of layernorm modules as equivalent\n",
    "# resid-pre = input of input_layernorm\n",
    "# resid-mid = input of post_attention_layernorm / pre_feedforward_layernorm (gemma)\n",
    "# resid-post = input of the next input_layernorm\n",
    "layernorm_modules = [\"input_layernorm\", \"post_attention_layernorm\"]\n",
    "if \"gemma\" in MODEL_PATH:\n",
    "    # \"post_attention_layernorm\" is already listed, so only the extra norm is\n",
    "    # added (listing a module twice would just rewrite the same config key).\n",
    "    # NOTE(review): the comment above says resid-mid is the input of\n",
    "    # pre_feedforward_layernorm for gemma, yet that module is not targeted here;\n",
    "    # confirm which norms the vLLM fork should steer for gemma models.\n",
    "    layernorm_modules.append(\"post_feedforward_layernorm\")\n",
    "\n",
    "mean_d = refusal_dirs_flatten.mean(axis=0)\n",
    "mean_d /= np.linalg.norm(mean_d)\n",
    "\n",
    "print(chosen_layer, chosen_act_idx)\n",
    "\n",
    "# loop-invariant inputs shared by every saved config\n",
    "num_layers = refusal_dirs.shape[0]\n",
    "second_direction = components[0].copy()\n",
    "\n",
    "\n",
    "def build_steering_config(first_direction, second_direction, num_layers, modules):\n",
    "    \"\"\"Return a dict mapping module names to a `rotate_to` steering spec.\n",
    "\n",
    "    For layer i, `input_layernorm` of layer i + 1 stands in for resid-post of\n",
    "    layer i, so it is skipped for the last layer; every other module in\n",
    "    `modules` is taken from layer i itself. All entries share the same two\n",
    "    directions.\n",
    "    \"\"\"\n",
    "    config = {}\n",
    "    for layer_idx in range(num_layers):\n",
    "        for module in modules:\n",
    "            if module != \"input_layernorm\":\n",
    "                module_name = f\"model.layers.{layer_idx}.{module}\"\n",
    "            elif layer_idx < num_layers - 1:\n",
    "                module_name = f\"model.layers.{layer_idx + 1}.{module}\"\n",
    "            else:\n",
    "                continue\n",
    "\n",
    "            config[module_name] = {\n",
    "                \"mode\": \"rotate_to\",\n",
    "                \"first_direction\": first_direction,\n",
    "                \"second_direction\": second_direction,\n",
    "            }\n",
    "    return config\n",
    "\n",
    "\n",
    "# saving various steering configs\n",
    "# NOTE(review): the max-norm / max-sim directions are saved without\n",
    "# re-normalization, unlike dir_mean; confirm the vLLM fork normalizes the bases.\n",
    "for first_direction, first_dir_name in [\n",
    "    (\n",
    "        refusal_dirs[max_norm_layer][max_norm_act_idx].copy(),\n",
    "        f\"dir_max_norm_{max_norm_layer}_{target_modules[max_norm_act_idx]}\",\n",
    "    ),\n",
    "    (\n",
    "        refusal_dirs[max_mean_cosine_layer][max_mean_cosine_act_idx].copy(),\n",
    "        f\"dir_max_sim_{max_mean_cosine_layer}_{target_modules[max_mean_cosine_act_idx]}\",\n",
    "    ),\n",
    "    (mean_d.copy(), \"dir_mean\"),\n",
    "]:\n",
    "    steering_config = build_steering_config(\n",
    "        first_direction, second_direction, num_layers, layernorm_modules\n",
    "    )\n",
    "\n",
    "    output_name = f\"{METHOD_PREFIX}steering_config-en-{first_dir_name}-pca_0.npy\"\n",
    "    print(output_name)\n",
    "    # np.save wraps the dict in a pickled 0-d object array; read it back with\n",
    "    # np.load(path, allow_pickle=True).item()\n",
    "    np.save(OUTPUT_DIR / output_name, steering_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing Angular Steering\n",
    "\n",
    "(this takes a while because inference with TransformerLens is slow)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Try with your own prompt\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
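    "# NOTE: this cell is disabled (every line is commented out).\n",
    "# `get_rotate_to_target_func` is a torch-tensor version of `rotate_to_target`.\n",
    "# The code relies on `activation_rotation_hook`, `refusal_dir` and\n",
    "# `intervention_layers`, which are only defined in the (also commented-out) cell\n",
    "# of the next section. It also needs `model` and `get_generations`, presumably\n",
    "# defined earlier in the notebook. Enable those definitions first.\n",
    "\n",
    "# def get_rotate_to_target_func(target_degree, basis1, basis2):\n",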
    "#     assert len(basis1.shape) == 1\n",
    "#     assert len(basis2.shape) == 1\n",
    "#     assert basis1.shape == basis2.shape\n",
    "\n",
    "#     n = basis1.shape[-1]\n",
    "\n",
    "#     # ensure bases are orthonormal\n",
    "#     u = basis1 / np.linalg.norm(basis1)\n",
    "#     v = basis2 - (basis2 @ u) * u\n",
    "#     v /= np.linalg.norm(v)\n",
    "\n",
    "#     theta = np.deg2rad(target_degree)\n",
    "#     cos_theta = np.cos(theta)\n",
    "#     sin_theta = np.sin(theta)\n",
    "\n",
    "#     P = np.outer(u, u) + np.outer(v, v)\n",
    "\n",
    "#     # rotate counter-clockwise\n",
    "#     R_theta = [[cos_theta, -sin_theta], [sin_theta, cos_theta]]\n",
    "\n",
    "#     uv = np.column_stack([u, v])\n",
    "\n",
    "#     rotated_component = uv @ R_theta @ np.array([1, 0])\n",
    "\n",
    "#     def __func(x: Tensor):\n",
    "#         Px = x @ torch.tensor(P, device=x.device, dtype=x.dtype)\n",
    "#         scale = Px.norm(dim=-1, keepdim=True)\n",
    "\n",
    "#         result = (\n",
    "#             x\n",
    "#             - Px\n",
    "#             + scale * torch.tensor(rotated_component, device=x.device, dtype=x.dtype)\n",
    "#         )\n",
    "\n",
    "#         return result\n",
    "\n",
    "#     return __func\n",
    "\n",
    "# transformation_func = get_rotate_to_target_func(\n",
    "#     target_degree=180, basis1=refusal_dir.copy(), basis2=components[0].copy()\n",
    "# )\n",
    "\n",
    "# fwd_hooks = [\n",
    "#     (\n",
    "#         utils.get_act_name(act_name, layer),\n",
    "#         functools.partial(\n",
    "#             activation_rotation_hook,\n",
    "#             transformation_func=transformation_func,\n",
    "#         ),\n",
    "#     )\n",
    "#     for layer in intervention_layers\n",
    "#     for act_idx, act_name in enumerate([\"resid_pre\", \"resid_mid\"])\n",
    "# ]\n",
    "\n",
    "# answer = get_generations(\n",
    "#     model,\n",
    "#     [\"how to pick a lock\"],\n",
    "#     model.tokenizer,\n",
    "#     fwd_hooks=fwd_hooks,\n",
    "#     max_tokens_generated=256,\n",
    "# )\n",
    "\n",
    "# for a in answer:\n",
    "#     print(a.split(model.tokenizer.eos_token)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Steered generation on the test set every 10 degrees of steering\n",
    "\n",
    "(very slow: completions are generated for each of the 36 steering angles 0, 10, ..., 350)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
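    "# NOTE: this cell is disabled (every line is commented out) because it is very\n",
    "# slow: it generates completions for every steering angle in range(0, 360, 10).\n",
    "# Besides `model`, `get_generations` and `harmful_inst_test`, it needs `Fore`,\n",
    "# `Float` and `HookPoint` in scope; these are presumably imported earlier in the\n",
    "# notebook, so confirm before enabling.\n",
    "\n",
    "# MAX_GENERATION_TOKENS = 512\n",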
    "# NUM_TEST_SAMPLES = 4\n",
    "# test_samples = harmful_inst_test[:NUM_TEST_SAMPLES]\n",
    "\n",
    "\n",
    "# def get_rotate_to_target_func(target_degree, basis1, basis2):\n",
    "#     assert len(basis1.shape) == 1\n",
    "#     assert len(basis2.shape) == 1\n",
    "#     assert basis1.shape == basis2.shape\n",
    "\n",
    "#     n = basis1.shape[-1]\n",
    "\n",
    "#     # ensure bases are orthonormal\n",
    "#     u = basis1 / np.linalg.norm(basis1)\n",
    "#     v = basis2 - (basis2 @ u) * u\n",
    "#     v /= np.linalg.norm(v)\n",
    "\n",
    "#     theta = np.deg2rad(target_degree)\n",
    "#     cos_theta = np.cos(theta)\n",
    "#     sin_theta = np.sin(theta)\n",
    "\n",
    "#     P = np.outer(u, u) + np.outer(v, v)\n",
    "\n",
    "#     # rotate counter-clockwise\n",
    "#     R_theta = [[cos_theta, -sin_theta], [sin_theta, cos_theta]]\n",
    "\n",
    "#     uv = np.column_stack([u, v])\n",
    "\n",
    "#     rotated_component = uv @ R_theta @ np.array([1, 0])\n",
    "\n",
    "#     def __func(x: Tensor):\n",
    "#         Px = x @ torch.tensor(P, device=x.device, dtype=x.dtype)\n",
    "#         scale = Px.norm(dim=-1, keepdim=True)\n",
    "\n",
    "#         result = (\n",
    "#             x\n",
    "#             - Px\n",
    "#             + scale * torch.tensor(rotated_component, device=x.device, dtype=x.dtype)\n",
    "#         )\n",
    "\n",
    "#         return result\n",
    "\n",
    "#     return __func\n",
    "\n",
    "\n",
    "# def activation_rotation_hook(\n",
    "#     activation: Float[Tensor, \"... d_act\"],\n",
    "#     hook: HookPoint,\n",
    "#     transformation_func,\n",
    "# ):\n",
    "#     return transformation_func(activation)\n",
    "\n",
    "\n",
    "# if not \"baseline_generations\" in locals():\n",
    "#     # if True:\n",
    "#     baseline_generations = get_generations(\n",
    "#         model,\n",
    "#         test_samples,\n",
    "#         model.tokenizer,\n",
    "#         fwd_hooks=[],\n",
    "#         max_tokens_generated=MAX_GENERATION_TOKENS,\n",
    "#     )\n",
    "\n",
    "# intervention_generations = {}\n",
    "# refusal_dir = refusal_dirs[chosen_layer][chosen_act_idx]\n",
    "# for degree in range(0, 360, 10):\n",
    "#     intervention_layers = list(range(model.cfg.n_layers))\n",
    "#     print(\"degree\", degree)\n",
    "#     print(intervention_layers)\n",
    "\n",
    "#     if degree in intervention_generations:\n",
    "#         continue\n",
    "\n",
    "#     transformation_func = get_rotate_to_target_func(\n",
    "#         target_degree=degree, basis1=refusal_dir.copy(), basis2=components[0].copy()\n",
    "#     )\n",
    "\n",
    "#     fwd_hooks = [\n",
    "#         (\n",
    "#             utils.get_act_name(act_name, layer),\n",
    "#             functools.partial(\n",
    "#                 activation_rotation_hook,\n",
    "#                 transformation_func=transformation_func,\n",
    "#             ),\n",
    "#         )\n",
    "#         for layer in intervention_layers\n",
    "#         for act_idx, act_name in enumerate([\"resid_pre\", \"resid_mid\"])\n",
    "#     ]\n",
    "\n",
    "#     intervention_generations[degree] = get_generations(\n",
    "#         model,\n",
    "#         test_samples,\n",
    "#         model.tokenizer,\n",
    "#         fwd_hooks=fwd_hooks,\n",
    "#         max_tokens_generated=MAX_GENERATION_TOKENS,\n",
    "#     )\n",
    "\n",
    "# for i in range(NUM_TEST_SAMPLES):\n",
    "#     print(f\"INSTRUCTION {i}: {repr(test_samples[i])}\")\n",
    "#     print(Fore.GREEN + f\"BASELINE COMPLETION:\")\n",
    "#     print(\n",
    "#         textwrap.fill(\n",
    "#             baseline_generations[i],\n",
    "#             width=100,\n",
    "#             initial_indent=\"\\t\",\n",
    "#             subsequent_indent=\"\\t\",\n",
    "#         )\n",
    "#     )\n",
    "#     print(Fore.RESET)\n",
    "#     for degree in sorted(intervention_generations.keys()):\n",
    "#         print(Fore.RED + f\"INTERVENTION COMPLETION (degree {degree}):\")\n",
    "#         print(intervention_generations[degree][i].split(model.tokenizer.eos_token)[0])\n",
    "#         # print(\n",
    "#         #     textwrap.fill(\n",
    "#         #         intervention_generations[extraction_layer][i],\n",
    "#         #         width=100,\n",
    "#         #         initial_indent=\"\\t\",\n",
    "#         #         subsequent_indent=\"\\t\",\n",
    "#         #     )\n",
    "#         # )\n",
    "#         print(Fore.RESET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "angular_steering",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "3877270cf4bc42a9b6142cce7a5d8c54": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_ec8f6f360a2243b0ac98d34e825ba378",
      "max": 2,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_f2ee188bfaa84e9680dbc296b1adbef6",
      "value": 2
     }
    },
    "89797f6e82104058af92e3ceb094af66": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "89ee88168c474e9fbcf4a17f1483eff4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c362d50107dd4a2db0d1a79da2af8d57",
      "placeholder": "​",
      "style": "IPY_MODEL_ffa85c694b694425999187b346c7ecfe",
      "value": "Loading checkpoint shards: 100%"
     }
    },
    "9a5611a341ed4673aaaf2f463f685d7c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_e973493cd6d14381bb4ad2f82417e8a9",
      "placeholder": "​",
      "style": "IPY_MODEL_89797f6e82104058af92e3ceb094af66",
      "value": " 2/2 [00:18&lt;00:00,  8.85s/it]"
     }
    },
    "a2de63dfbd6c485e841c6fcd1fefe451": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "ad063e2c68a44f009bfab68c141c09be": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_89ee88168c474e9fbcf4a17f1483eff4",
       "IPY_MODEL_3877270cf4bc42a9b6142cce7a5d8c54",
       "IPY_MODEL_9a5611a341ed4673aaaf2f463f685d7c"
      ],
      "layout": "IPY_MODEL_a2de63dfbd6c485e841c6fcd1fefe451"
     }
    },
    "c362d50107dd4a2db0d1a79da2af8d57": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "e973493cd6d14381bb4ad2f82417e8a9": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "ec8f6f360a2243b0ac98d34e825ba378": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f2ee188bfaa84e9680dbc296b1adbef6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "ffa85c694b694425999187b346c7ecfe": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
