{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "82acAhWYGIPx"
   },
   "source": [
    "# Angular Steering\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "j7hOtw7UHXdD"
   },
   "source": [
    "This notebook contains:\n",
    "\n",
    "- The process of extracting the refusal direction and constructing the steering plane.\n",
    "- Visualization of the activations, the extracted directions, and the constructed steering planes.\n",
    "- The creation of the steering config that can be used with our fork of vLLM.\n",
    "\n",
    "This notebook is inspired by https://colab.research.google.com/drive/1a-aQvKC9avdZpdyBn4jgRQFObTPy1JZw\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fcxHyDZw6b86"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dependencies\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dLeei4-T6Wef"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Use %pip (not !pip) so the packages are installed into the environment of the running kernel.\n",
    "%pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama nbformat plotly datasets pandas scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_vhhwl-2-jPg"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import functools\n",
    "import einops\n",
    "import requests\n",
    "import pandas as pd\n",
    "import io\n",
    "import textwrap\n",
    "import gc\n",
    "import numpy as np\n",
    "import plotly\n",
    "import os\n",
    "\n",
    "\n",
    "from pathlib import Path\n",
    "from datasets import load_dataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm import tqdm\n",
    "from torch import Tensor\n",
    "from typing import List\n",
    "from transformer_lens import HookedTransformer, utils, ActivationCache\n",
    "from transformer_lens.hook_points import HookPoint\n",
    "from transformers import AutoTokenizer\n",
    "from jaxtyping import Float, Int\n",
    "from colorama import Fore\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "\n",
    "\n",
    "from transformers import AutoModelForCausalLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6ZOoJagxD49V"
   },
   "source": [
    "### Models and configs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 191,
     "referenced_widgets": [
      "ad063e2c68a44f009bfab68c141c09be",
      "89ee88168c474e9fbcf4a17f1483eff4",
      "3877270cf4bc42a9b6142cce7a5d8c54",
      "9a5611a341ed4673aaaf2f463f685d7c",
      "a2de63dfbd6c485e841c6fcd1fefe451",
      "c362d50107dd4a2db0d1a79da2af8d57",
      "ffa85c694b694425999187b346c7ecfe",
      "ec8f6f360a2243b0ac98d34e825ba378",
      "f2ee188bfaa84e9680dbc296b1adbef6",
      "e973493cd6d14381bb4ad2f82417e8a9",
      "89797f6e82104058af92e3ceb094af66"
     ]
    },
    "id": "Vnp65Vsg5x-5",
    "outputId": "25fb5805-fe31-44b0-8f73-6fabc230d261"
   },
   "outputs": [],
   "source": [
    "# Choose one model for the experiment\n",
    "MODEL_PATH = (\n",
    "    # \"Qwen/Qwen2.5-3B-Instruct\"\n",
    "    # \"Qwen/Qwen2.5-7B-Instruct\"\n",
    "    # \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "    # \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "    # \"meta-llama/Llama-3.2-3B-Instruct\"\n",
    "    # \"meta-llama/Llama-3.1-8B-Instruct\"\n",
    "    \"google/gemma-2-9b-it\"\n",
    ")\n",
    "# Indices along the first axis of the extracted activation arrays that are analyzed below\n",
    "extraction_point = list(range(44, 45))\n",
    "# Numbers of KMeans clusters to try\n",
    "k_grid = [10]\n",
    "MODEL_NAME = MODEL_PATH.split(\"/\")[-1]\n",
    "\n",
    "# Which layer-norm outputs are read for Gemma models (even / odd extraction points):\n",
    "# Pre MLP (k) + Pre Attn (k+1): Nothing (Even = False, Odd = False)\n",
    "# Post Attn (k) + Pre Attn (k+1): A_steering (Even = True, Odd = False)\n",
    "# Pre MLP (k) + Post MLP (k): B_steering (Even = False, Odd = True)\n",
    "# Post Attn (k) + Post MLP (k): C_steering (Even = True, Odd = True)\n",
    "GEMMA_RELOCATE_MODE = True\n",
    "# The relocation only applies to Gemma models; switch it off for everything else\n",
    "if not (\"gemma\" in MODEL_PATH or \"Gemma\" in MODEL_PATH):\n",
    "    GEMMA_RELOCATE_MODE = False\n",
    "if GEMMA_RELOCATE_MODE:\n",
    "    GEMMA_RELOCATE_EVEN = True\n",
    "    GEMMA_RELOCATE_ODD = True\n",
    "else:\n",
    "    GEMMA_RELOCATE_EVEN = False\n",
    "    GEMMA_RELOCATE_ODD = False\n",
    "\n",
    "# File-name prefix that encodes the relocation variant (A_, B_, C_ or empty)\n",
    "gemma_save_file_append = \"\"\n",
    "if GEMMA_RELOCATE_EVEN and not GEMMA_RELOCATE_ODD:\n",
    "    gemma_save_file_append = \"A_\"\n",
    "elif not GEMMA_RELOCATE_EVEN and GEMMA_RELOCATE_ODD:\n",
    "    gemma_save_file_append = \"B_\"\n",
    "elif GEMMA_RELOCATE_EVEN and GEMMA_RELOCATE_ODD:\n",
    "    gemma_save_file_append = \"C_\"\n",
    "\n",
    "DEVICE = \"cuda:0\"\n",
    "# lmda, sim and max_pcs are part of the output directory names built further down\n",
    "lmda = \"adaptive\"\n",
    "sim = \"adaptive_gaussian\"\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "max_pcs = 15\n",
    "\n",
    "OUTPUT_PARENT_DIR = Path(\"output\") / f\"{MODEL_NAME}\" / \"CHaRS_PCT\"\n",
    "OUTPUT_PARENT_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "VISUALIZATION_PARENT_DIR = Path(\"visualization\") / f\"{MODEL_NAME}\" / \"CHaRS_PCT\"\n",
    "\n",
    "\n",
    "# Keep Hugging Face downloads inside the working directory\n",
    "CACHE_DIR = Path(os.getcwd()) / \"huggingface\"\n",
    "MODEL_CACHE_DIR = CACHE_DIR / \"hub\"\n",
    "DATASETS_CACHE_DIR = CACHE_DIR / \"datasets\"\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(MODEL_PATH,\n",
    "                                             device_map = \"auto\",\n",
    "                                             torch_dtype=torch.bfloat16,\n",
    "                                             cache_dir = MODEL_CACHE_DIR)\n",
    "# device_map / torch_dtype are model-loading options and mean nothing to a tokenizer,\n",
    "# so only the cache directory is passed here.\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH,\n",
    "                                          cache_dir = MODEL_CACHE_DIR)\n",
    "\n",
    "# Left padding keeps the last tokens of every sample aligned at the end of the sequence\n",
    "tokenizer.padding_side = \"left\"\n",
    "\n",
    "# store original chat template\n",
    "ORIGINAL_CHAT_TEMPLATE = tokenizer.chat_template\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the tokenizer has a pad token before any padded batch is built\n",
    "if not tokenizer.pad_token:\n",
    "    is_qwen1 = \"qwen1\" in MODEL_PATH.lower()\n",
    "    if not (is_qwen1 or tokenizer.eos_token):\n",
    "        raise ValueError(\"No pad token found in the tokenizer.\")\n",
    "    # qwen1 checkpoints get a fixed pad token; every other model reuses its EOS token\n",
    "    tokenizer.pad_token = \"<|endoftext|>\" if is_qwen1 else tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rF7e-u20EFTe"
   },
   "source": [
    "### Load harmful / harmless datasets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5i1XcVIgHEE1"
   },
   "outputs": [],
   "source": [
    "def get_harmful_instructions():\n",
    "    \"\"\"Download the AdvBench harmful-behaviors CSV and split its `goal` column.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): two lists of instruction strings from an 80/20 split\n",
    "        with a fixed seed (random_state=42).\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: if the server answers with an error status.\n",
    "        requests.Timeout: if the download does not complete in time.\n",
    "    \"\"\"\n",
    "    url = \"https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv\"\n",
    "    # Fail fast instead of hanging forever or parsing an HTML error page as CSV\n",
    "    response = requests.get(url, timeout=60)\n",
    "    response.raise_for_status()\n",
    "\n",
    "    dataset = pd.read_csv(io.StringIO(response.content.decode(\"utf-8\")))\n",
    "    instructions = dataset[\"goal\"].tolist()\n",
    "\n",
    "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
    "    return train, test\n",
    "\n",
    "\n",
    "def get_harmless_instructions():\n",
    "    \"\"\"Load the Alpaca dataset and split its input-free instructions.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): two lists of instruction strings from an 80/20 split\n",
    "        with a fixed seed (random_state=42).\n",
    "    \"\"\"\n",
    "    hf_path = \"tatsu-lab/alpaca\"\n",
    "    dataset = load_dataset(hf_path, cache_dir = DATASETS_CACHE_DIR)\n",
    "\n",
    "    # filter for instructions that do not have inputs; iterate the rows once\n",
    "    # instead of indexing the dataset twice per row\n",
    "    instructions = [\n",
    "        item[\"instruction\"]\n",
    "        for item in dataset[\"train\"]\n",
    "        if item[\"input\"].strip() == \"\"\n",
    "    ]\n",
    "\n",
    "    train, test = train_test_split(instructions, test_size=0.2, random_state=42)\n",
    "    return train, test\n",
    "\n",
    "\n",
    "def get_harmful_instructions_jp():\n",
    "    \"\"\"Load the Japanese do-not-answer dataset and split the selected harmful questions.\n",
    "\n",
    "    Only rows whose `types_of_harm` is one of the enabled categories are kept.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): two lists of question strings from an 80/20 split\n",
    "        with a fixed seed (random_state=42).\n",
    "    \"\"\"\n",
    "    dataset = load_dataset(\"kunishou/do-not-answer-ja\", cache_dir = DATASETS_CACHE_DIR)\n",
    "\n",
    "    # Enabled harm categories; the commented-out entries are the categories left out\n",
    "    selected_types = [\n",
    "        # # Information Hazards\n",
    "        # \"Risks from leaking or inferring sensitive information (organization/gov)\",\n",
    "        # (\n",
    "        #     \"Compromise privacy by leaking or inferring private information\"\n",
    "        #     \" (person/individual)\"\n",
    "        # ),\n",
    "        # Malicious Uses\n",
    "        \"Assisting illegal activities\",\n",
    "        \"Nudging or advising users to perform unethical or unsafe actions\",\n",
    "        \"Reducing the cost of disinformation campaigns\",\n",
    "        # Discrimination, Exclusion, Toxicity, Hateful, Offensive\n",
    "        \"Social stereotypes and unfair discrimination\",\n",
    "        \"Toxic language (hate speech)\",\n",
    "        \"Adult Content\",\n",
    "        # # Misinformation Harms\n",
    "        # \"Causing material harm by disseminating misinformation e.g. in medicine or law\",\n",
    "        # \"Disseminating false or misleading information\",\n",
    "        # # Human-Chatbot Interaction Harms\n",
    "        # \"Mental Health or Overreliance Crisis\",\n",
    "        # \"Treat Chatbot as a Human\",\n",
    "    ]\n",
    "\n",
    "    questions = [\n",
    "        row[\"question\"]\n",
    "        for row in dataset[\"train\"]\n",
    "        if row[\"types_of_harm\"] in selected_types\n",
    "    ]\n",
    "\n",
    "    train_part, test_part = train_test_split(questions, test_size=0.2, random_state=42)\n",
    "    return train_part, test_part\n",
    "\n",
    "\n",
    "def get_harmless_instructions_jp():\n",
    "    \"\"\"Load the Japanese Alpaca dataset and split its input-free instructions.\n",
    "\n",
    "    Returns:\n",
    "        (train, test): two lists of instruction strings from an 80/20 split\n",
    "        with a fixed seed (random_state=42).\n",
    "    \"\"\"\n",
    "    dataset = load_dataset(\"Lazycuber/alpaca-jp\", cache_dir = DATASETS_CACHE_DIR)\n",
    "\n",
    "    # keep rows without an extra input and strip quote characters from both ends\n",
    "    cleaned = [\n",
    "        row[\"instruction\"].strip(\"「」'\")\n",
    "        for row in dataset[\"train\"]\n",
    "        if row[\"input\"].strip() == \"\"\n",
    "    ]\n",
    "\n",
    "    train_part, test_part = train_test_split(cleaned, test_size=0.2, random_state=42)\n",
    "    return train_part, test_part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Rth8yvLZJsXs"
   },
   "outputs": [],
   "source": [
    "LANGUAGE = \"en\"\n",
    "\n",
    "# Load the train/test instruction splits for the selected language\n",
    "if LANGUAGE == \"en\":\n",
    "    harmful_inst_train, harmful_inst_test = get_harmful_instructions()\n",
    "    harmless_inst_train, harmless_inst_test = get_harmless_instructions()\n",
    "elif LANGUAGE == \"jp\":\n",
    "    harmful_inst_train, harmful_inst_test = get_harmful_instructions_jp()\n",
    "    harmless_inst_train, harmless_inst_test = get_harmless_instructions_jp()\n",
    "else:\n",
    "    # Fail here with a clear message rather than with a NameError further down\n",
    "    raise ValueError(f\"Unsupported LANGUAGE {LANGUAGE!r}; expected 'en' or 'jp'.\")\n",
    "\n",
    "print(f\"Train: {len(harmful_inst_train)} harmful, {len(harmless_inst_train)} harmless\")\n",
    "print(f\"Test: {len(harmful_inst_test)} harmful, {len(harmless_inst_test)} harmless\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Qv2ALDY_J44G",
    "outputId": "7cce7654-9592-4414-a322-27c6631f2e8e"
   },
   "outputs": [],
   "source": [
    "# Show the first four training instructions of each class\n",
    "for heading, examples in (\n",
    "    (\"Harmful instructions:\", harmful_inst_train),\n",
    "    (\"Harmless instructions:\", harmless_inst_train),\n",
    "):\n",
    "    print(heading)\n",
    "    for idx in range(4):\n",
    "        print(f\"\\t{examples[idx]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KOKYA61k8LWt"
   },
   "source": [
    "### Tokenization utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P8UPQSfpWOSK"
   },
   "outputs": [],
   "source": [
    "def instructions_to_chat_tokens(\n",
    "    tokenizer: AutoTokenizer,\n",
    "    instructions: List[str],\n",
    ") -> Int[Tensor, \"batch_size seq_len\"]:\n",
    "    \"\"\"Tokenize a list of instructions into one padded batch of token ids.\n",
    "\n",
    "    When the tokenizer has a chat template, every instruction is wrapped as a\n",
    "    single user turn and the generation prompt is appended; otherwise the raw\n",
    "    strings are tokenized directly. Nothing is truncated.\n",
    "    \"\"\"\n",
    "    # No chat template: plain tokenization of the raw strings\n",
    "    if not tokenizer.chat_template:\n",
    "        encoded = tokenizer(\n",
    "            instructions, padding=True, truncation=False, return_tensors=\"pt\"\n",
    "        )\n",
    "        return encoded.input_ids\n",
    "\n",
    "    conversations = []\n",
    "    for instruction in instructions:\n",
    "        conversations.append([{\"role\": \"user\", \"content\": instruction}])\n",
    "\n",
    "    return tokenizer.apply_chat_template(\n",
    "        conversations,\n",
    "        padding=True,\n",
    "        truncation=False,\n",
    "        add_generation_prompt=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W9O8dm0_EQRk"
   },
   "source": [
    "## Finding the \"refusal direction\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_template_suffix_toks(tokenizer):\n",
    "    \"\"\"Return the tokens that the chat template appends after the user prompt.\n",
    "\n",
    "    Two different one-letter prompts are tokenized and their common suffix is\n",
    "    returned as a list of token strings. If the two sequences are identical,\n",
    "    the whole first sequence is returned.\n",
    "    \"\"\"\n",
    "    # Since the padding is on the left side, the suffix of all samples are the same\n",
    "    # when using the same template.\n",
    "    # The activations on these suffix tokens are after the prompt has been processed,\n",
    "    # thus it's interesting to see how the activations differ between contrastive\n",
    "    # samples\n",
    "\n",
    "    # get the common suffix between 2 samples\n",
    "    toks = instructions_to_chat_tokens(tokenizer=tokenizer, instructions=[\"a\", \"b\"])\n",
    "    first, second = toks[0], toks[1]\n",
    "    suffix = first\n",
    "    for i in range(len(first) - 1, -1, -1):\n",
    "        if first[i] != second[i]:\n",
    "            suffix = first[i + 1 :]\n",
    "            # Stop at the LAST differing position; scanning on would cut the suffix\n",
    "            # at the earliest mismatch and include tokens that are not shared.\n",
    "            break\n",
    "\n",
    "    return tokenizer.convert_ids_to_tokens(suffix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "\n",
    "def _make_capture_hook(store, num_last_tokens):\n",
    "    \"\"\"Build a forward hook that records a module's output without changing it.\n",
    "\n",
    "    The hook appends the last `num_last_tokens` positions of the module output\n",
    "    (its first element when the output is a tuple) to `store` as a detached CPU copy.\n",
    "    \"\"\"\n",
    "\n",
    "    def hook_fn(module, input, output):\n",
    "        hidden = output[0] if isinstance(output, tuple) else output\n",
    "        assert len(hidden.shape) == 3, \"Output Shape is Not (B, L, D)\"\n",
    "        store.append(hidden[:, -num_last_tokens:, :].detach().clone().cpu())\n",
    "        return output\n",
    "\n",
    "    return hook_fn\n",
    "\n",
    "\n",
    "def _select_hooked_modules(model, model_name, gemma_relocate_even, gemma_relocate_odd):\n",
    "    \"\"\"Pick the layer-norm modules whose outputs are captured.\n",
    "\n",
    "    Layer 0 contributes one module (its \"even\" point); every later layer contributes\n",
    "    an \"odd\" and an \"even\" module, i.e. 2 * num_layers - 1 modules in total.\n",
    "    The relocation flags are only honoured for Gemma-2 / Gemma-3 model names.\n",
    "    \"\"\"\n",
    "    is_gemma3 = \"gemma-3\" in model_name\n",
    "    is_gemma2 = (not is_gemma3) and (\"gemma-2\" in model_name or \"Gemma-2\" in model_name)\n",
    "    is_gemma = is_gemma3 or is_gemma2\n",
    "    layers = model.language_model.model.layers if is_gemma3 else model.model.layers\n",
    "\n",
    "    use_pre_feedforward = is_gemma and not gemma_relocate_even\n",
    "    use_post_feedforward = is_gemma and gemma_relocate_odd\n",
    "\n",
    "    # Logging\n",
    "    if is_gemma3:\n",
    "        print(\"This is a Gemma 3 Model in Use!\")\n",
    "    elif is_gemma2:\n",
    "        print(\"This is a Gemma 2 Model in Use!\")\n",
    "    if is_gemma:\n",
    "        if use_pre_feedforward:\n",
    "            print(\"Even: Using Pre Feedforward LayerNorm!\")\n",
    "        else:\n",
    "            print(\"Even: Using Post Attention LayerNorm!\")\n",
    "        if len(layers) > 1:\n",
    "            if use_post_feedforward:\n",
    "                print(\"Odd: Using Post Feedforward LayerNorm!\")\n",
    "            else:\n",
    "                print(\"Odd: Using Input LayerNorm!\")\n",
    "\n",
    "    modules = []\n",
    "    for layer_no, layer in enumerate(layers):\n",
    "        # Odd point: between the previous block and this one (none before layer 0)\n",
    "        if layer_no > 0:\n",
    "            if use_post_feedforward:\n",
    "                # Post-MLP norm of the PREVIOUS block, for Gemma-2 and Gemma-3 alike.\n",
    "                # The Gemma-2 branch used to hook layers[layer_no], which shifted the\n",
    "                # odd points by one block and hooked the final layer's post-MLP norm.\n",
    "                modules.append(layers[layer_no - 1].post_feedforward_layernorm)\n",
    "            else:\n",
    "                modules.append(layer.input_layernorm)\n",
    "        # Even point: between attention and MLP of this block\n",
    "        if use_pre_feedforward:\n",
    "            modules.append(layer.pre_feedforward_layernorm)\n",
    "        else:\n",
    "            modules.append(layer.post_attention_layernorm)\n",
    "    return modules\n",
    "\n",
    "\n",
    "def get_activations(\n",
    "    model,\n",
    "    model_name: str,\n",
    "    tokenizer,\n",
    "    instructions: List[str],\n",
    "    batch_size: int = BATCH_SIZE,\n",
    "    num_last_tokens: int = 1,\n",
    "    gemma_relocate_even: bool = False,\n",
    "    gemma_relocate_odd: bool = False,\n",
    "):\n",
    "    \"\"\"Collect layer-norm outputs at the last token positions for a list of instructions.\n",
    "\n",
    "    Args:\n",
    "        model: causal LM whose decoder layers are hooked.\n",
    "        model_name: used to detect Gemma-2 / Gemma-3 models.\n",
    "        tokenizer: tokenizer used to build the (left-padded) chat inputs.\n",
    "        instructions: prompts to run through the model.\n",
    "        batch_size: number of prompts per forward pass.\n",
    "        num_last_tokens: number of trailing token positions to keep.\n",
    "        gemma_relocate_even: Gemma only - read the even points after the\n",
    "            post-attention norm instead of the pre-feedforward norm.\n",
    "        gemma_relocate_odd: Gemma only - read the odd points after the previous\n",
    "            block's post-feedforward norm instead of this block's input norm.\n",
    "\n",
    "    Returns:\n",
    "        CPU tensor of shape (2 * num_layers - 1, n_instructions, tokens, dim), where\n",
    "        tokens is at most num_last_tokens. The first axis follows the order in which\n",
    "        the hooked modules run during the forward pass.\n",
    "\n",
    "    Note:\n",
    "        Activation files cached for Gemma-2 with gemma_relocate_odd=True before the\n",
    "        odd-point fix use the old ordering and should be regenerated.\n",
    "    \"\"\"\n",
    "    # tokenize instructions\n",
    "    toks = instructions_to_chat_tokens(\n",
    "        tokenizer=tokenizer, instructions=instructions\n",
    "    ).to(model.device)\n",
    "    attention_mask = (toks != tokenizer.pad_token_id).long().to(model.device)\n",
    "\n",
    "    # Adding Hooks to Layer Norms;\n",
    "    # the input layernorm of layer 0 and the output of the final block are not captured\n",
    "    hooked_modules = _select_hooked_modules(\n",
    "        model, model_name, gemma_relocate_even, gemma_relocate_odd\n",
    "    )\n",
    "\n",
    "    activations_full = []\n",
    "\n",
    "    for batch in tqdm.tqdm(range(0, toks.shape[0], batch_size)):\n",
    "\n",
    "        toks_batch = toks[batch: batch + batch_size]\n",
    "        attn_mask = attention_mask[batch: batch + batch_size]\n",
    "\n",
    "        activations = []\n",
    "        hook_fn = _make_capture_hook(activations, num_last_tokens)\n",
    "        handles = [module.register_forward_hook(hook_fn) for module in hooked_modules]\n",
    "\n",
    "        try:\n",
    "            # Forward Pass\n",
    "            with torch.no_grad():\n",
    "                _ = model(toks_batch, attention_mask = attn_mask)\n",
    "        finally:\n",
    "            # Hook Removal - also when the forward pass fails, so no hooks leak\n",
    "            for handle in handles:\n",
    "                handle.remove()\n",
    "\n",
    "        # Activation Processing: (extraction points, batch, tokens, dim)\n",
    "        activations_full.append(torch.stack(activations, dim = 0))\n",
    "\n",
    "    # concatenate the batches along the batch axis\n",
    "    activations_full = torch.cat(activations_full, dim = 1)\n",
    "    return activations_full"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract the activations\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_INST_TRAIN = 512\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "# extraction points per decoder block\n",
    "act_names = [\"resid_mid\", \"resid_post\"]\n",
    "\n",
    "# get the template suffix tokens\n",
    "template_suffix_toks = get_template_suffix_toks(tokenizer)\n",
    "if not template_suffix_toks:\n",
    "    template_suffix_toks = [\"<last token>\"]\n",
    "\n",
    "# only get the activations of the template suffix tokens since these tokens are the same\n",
    "# for all samples\n",
    "num_last_tokens = len(template_suffix_toks)\n",
    "print(\"template_suffix_toks:\", template_suffix_toks)\n",
    "\n",
    "\n",
    "def load_or_compute_activations(kind, instructions):\n",
    "    \"\"\"Return the activations for one prompt class, using the on-disk cache if present.\n",
    "\n",
    "    Args:\n",
    "        kind: \"harmful\" or \"harmless\"; part of the cache file name.\n",
    "        instructions: prompts to run when no cached file exists.\n",
    "\n",
    "    Returns:\n",
    "        Tensor loaded from the cached .npy file, or the freshly computed float32\n",
    "        CPU activations (which are also written to that file).\n",
    "\n",
    "    Note:\n",
    "        The cache file name does not depend on N_INST_TRAIN or num_last_tokens;\n",
    "        delete the file after changing either of them.\n",
    "    \"\"\"\n",
    "    output_file = OUTPUT_PARENT_DIR / f\"{gemma_save_file_append}acts_{kind}_{LANGUAGE}_{MODEL_PATH.split('/')[-1]}.npy\"\n",
    "\n",
    "    # load from cache if exists\n",
    "    if output_file.exists():\n",
    "        print(f\"Loading {kind} activations from file\")\n",
    "        return torch.from_numpy(np.load(output_file))\n",
    "\n",
    "    # get activations for the instructions then save to file\n",
    "    acts = get_activations(\n",
    "        model,\n",
    "        MODEL_NAME,\n",
    "        tokenizer,\n",
    "        instructions,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        num_last_tokens=num_last_tokens,\n",
    "        gemma_relocate_even=GEMMA_RELOCATE_EVEN,\n",
    "        gemma_relocate_odd=GEMMA_RELOCATE_ODD,\n",
    "    )\n",
    "    acts = acts.cpu().float()\n",
    "    np.save(output_file, acts.numpy())\n",
    "    return acts\n",
    "\n",
    "\n",
    "harmful_acts = load_or_compute_activations(\"harmful\", harmful_inst_train[:N_INST_TRAIN])\n",
    "harmless_acts = load_or_compute_activations(\"harmless\", harmless_inst_train[:N_INST_TRAIN])\n",
    "\n",
    "print(\"Harmful Acts Shape:\", harmful_acts.shape)\n",
    "print(\"Harmless Acts Shape:\", harmless_acts.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analyze the activations\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tqD5E8Vc_w5d"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Float32 NumPy copies of the activations: extraction points x batch x tokens x dim.\n",
    "# Despite the `_normed` suffix, no normalization is applied in this cell.\n",
    "harmless_acts_normed = harmless_acts.cpu().float().numpy()\n",
    "harmful_acts_normed = harmful_acts.cpu().float().numpy()\n",
    "\n",
    "total_ex_pt, hidden_dim = harmful_acts.shape[0], harmful_acts.shape[-1]\n",
    "\n",
    "# collect garbage and release cached CUDA memory\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tqdm\n",
    "import einops\n",
    "\n",
    "\n",
    "def get_clusters_single(acts_normed, k):\n",
    "    \"\"\"Cluster the last-token activations of one extraction point with KMeans.\n",
    "\n",
    "    Args:\n",
    "        acts_normed: array of shape (B, L, D); only the last token position is used.\n",
    "        k: number of clusters.\n",
    "\n",
    "    Returns:\n",
    "        (cluster centers, per-sample labels, inertia) of the fitted KMeans model.\n",
    "    \"\"\"\n",
    "    assert len(acts_normed.shape) == 3, f\"Invalid Shape: {acts_normed.shape}\"\n",
    "    # (B, L, D) -> (B, D): keep the final token only\n",
    "    last_token_acts = acts_normed[:, -1, :]\n",
    "    fitted = KMeans(n_clusters=k, random_state=0, n_init=\"auto\").fit(last_token_acts)\n",
    "    return fitted.cluster_centers_, fitted.labels_, fitted.inertia_\n",
    "\n",
    "# Compute Cluster Weights (Marginals)\n",
    "def compute_cluster_weights_single(acts_normed_labels, k):\n",
    "    X = acts_normed_labels\n",
    "    acts_normed_cluster_weights = np.array([np.sum(X == cl) / len(X) for cl in range(k)])\n",
    "    return acts_normed_cluster_weights\n",
    "\n",
    "def calculate_cost_matrix(harmful_acts_normed_clusters, harmless_acts_normed_clusters):\n",
    "    return np.sum((harmful_acts_normed_clusters[:, None, :] - harmless_acts_normed_clusters[None, :, :]) ** 2, axis=-1)\n",
    "\n",
    "def calculate_kernel_matrix(cost_matrix, eps=0.1):\n",
    "    return np.exp(-cost_matrix / eps)\n",
    "\n",
    "def calculate_kernel_matrix_adaptive(cost_matrix):\n",
    "    medians = np.median(cost_matrix, axis = (-1, -2), keepdims=True)\n",
    "    return np.exp(-cost_matrix / medians)\n",
    "\n",
    "def sinkhorn_knopp_single(harmful_acts_normed_cluster_weights_sliced, harmless_acts_normed_cluster_weights_sliced, kernel_matrix_sliced, max_iter=1000, tau=1e-6):\n",
    "    # A is harmful, B is harmless? Since OT is symmetric\n",
    "    u = np.ones(k) # Row scaling vector\n",
    "    v = np.ones(k) # Column scaling vector\n",
    "\n",
    "    for t in range(max_iter):\n",
    "        v_prev = v.copy()\n",
    "        u = harmful_acts_normed_cluster_weights_sliced / (kernel_matrix_sliced @ v) # Update u: Row normalization \n",
    "        v = harmless_acts_normed_cluster_weights_sliced / (kernel_matrix_sliced.T @ u) # Update v: Column normalization \n",
    "\n",
    "        if np.linalg.norm(v - v_prev, ord=1) < tau:\n",
    "            break\n",
    "\n",
    "    P_star = np.diag(u) @ kernel_matrix_sliced @ np.diag(v)\n",
    "    return P_star\n",
    "\n",
    "\n",
    "def _steering_module_name(point):\n",
    "    \"\"\"Return the name of the module that hosts extraction point `point`.\n",
    "\n",
    "    `point // 2` is used as the layer index and `point % 2` selects one of the\n",
    "    two hook positions of that layer. The mapping depends on the notebook-level\n",
    "    settings GEMMA_RELOCATE_MODE and MODEL_NAME, which are read at call time.\n",
    "    \"\"\"\n",
    "    lay, pos = point // 2, point % 2\n",
    "    if GEMMA_RELOCATE_MODE:\n",
    "        if pos == 0:\n",
    "            return f'model.layers.{lay}.post_attention_layernorm'\n",
    "        return f'model.layers.{lay}.post_feedforward_layernorm'\n",
    "    if pos != 0:\n",
    "        # Odd points are attached to the input norm of the following layer.\n",
    "        return f'model.layers.{lay + 1}.input_layernorm'\n",
    "    if \"gemma\" in MODEL_NAME or \"Gemma\" in MODEL_NAME:\n",
    "        return f'model.layers.{lay}.pre_feedforward_layernorm'\n",
    "    return f'model.layers.{lay}.post_attention_layernorm'\n",
    "\n",
    "\n",
    "# For every extraction point and every cluster count k in k_grid: cluster both\n",
    "# activation sets, couple the clusters with a Sinkhorn transport plan, run a\n",
    "# plan-weighted PCA on the pairwise displacement vectors, plot the explained\n",
    "# variance, and save the resulting steering configs to disk.\n",
    "for ex_pt in extraction_point:\n",
    "    print(\"Extraction Point:\", ex_pt)\n",
    "    harmful_acts_normed_ex_pt, harmless_acts_normed_ex_pt = harmful_acts_normed[ex_pt], harmless_acts_normed[ex_pt]\n",
    "    # Inertia per k, collected for the (currently disabled) elbow plot below.\n",
    "    inertia_harmful, inertia_harmless = [], []\n",
    "    for k in tqdm.tqdm(k_grid):\n",
    "        run_tag = f\"k{k}_sim{sim}_lambda{lmda}_maxpc{max_pcs}\"\n",
    "        OUTPUT_DIR = OUTPUT_PARENT_DIR / run_tag\n",
    "\n",
    "        VISUALIZATION_DIR = VISUALIZATION_PARENT_DIR / run_tag\n",
    "        VISUALIZATION_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "        # Relocate mode only changes the per-point directory name (\"r_\" prefix).\n",
    "        layer_dir_name = f\"r_{str(ex_pt)}\" if GEMMA_RELOCATE_MODE else str(ex_pt)\n",
    "        OUTPUT_DIR_REG = OUTPUT_DIR / \"layers\" / layer_dir_name\n",
    "        OUTPUT_DIR_DA = OUTPUT_DIR / \"dirablate_layers\" / layer_dir_name\n",
    "        OUTPUT_DIR_REG.mkdir(parents=True, exist_ok=True)\n",
    "        OUTPUT_DIR_DA.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "        # Clustering\n",
    "        harmful_acts_normed_clusters_ex_pt, harmful_acts_normed_labels_ex_pt, harmful_inertia_ex_pt = get_clusters_single(harmful_acts_normed_ex_pt, k)\n",
    "        harmless_acts_normed_clusters_ex_pt, harmless_acts_normed_labels_ex_pt, harmless_inertia_ex_pt = get_clusters_single(harmless_acts_normed_ex_pt, k)\n",
    "\n",
    "        # Storing inertia for the elbow plots\n",
    "        inertia_harmful.append(harmful_inertia_ex_pt)\n",
    "        inertia_harmless.append(harmless_inertia_ex_pt)\n",
    "\n",
    "        # Compute cluster weights\n",
    "        harmful_acts_normed_cluster_weights_ex_pt = compute_cluster_weights_single(harmful_acts_normed_labels_ex_pt, k)\n",
    "        harmless_acts_normed_cluster_weights_ex_pt = compute_cluster_weights_single(harmless_acts_normed_labels_ex_pt, k)\n",
    "\n",
    "        if k > 1:\n",
    "            # Kernel matrix\n",
    "            cost_matrix_ex_pt = calculate_cost_matrix(harmful_acts_normed_clusters_ex_pt, harmless_acts_normed_clusters_ex_pt)\n",
    "            kernel_matrix_ex_pt = calculate_kernel_matrix_adaptive(cost_matrix_ex_pt)\n",
    "\n",
    "            # Transport plan P* between harmful (rows) and harmless (columns) clusters\n",
    "            P_star_ex_pt = sinkhorn_knopp_single(harmful_acts_normed_cluster_weights_ex_pt, harmless_acts_normed_cluster_weights_ex_pt, kernel_matrix_ex_pt, max_iter=1000, tau=1e-6)\n",
    "        else:\n",
    "            # A single cluster pair would have the trivial plan [[1.0]], but the\n",
    "            # rest of this pipeline is not implemented for that case.\n",
    "            raise NotImplementedError(\"k <= 1 is not supported by this pipeline; choose k > 1\")\n",
    "\n",
    "        # Pairwise displacements: V[i, j] = harmless_cluster[j] - harmful_cluster[i]\n",
    "        V_ex_pt = harmless_acts_normed_clusters_ex_pt[None, :, :] - harmful_acts_normed_clusters_ex_pt[:, None, :]\n",
    "        # Spot-check one entry of the broadcast. The probe indices are clamped to\n",
    "        # the actual number of clusters so that small k does not raise IndexError.\n",
    "        i_test = min(2, V_ex_pt.shape[0] - 1)\n",
    "        j_test = min(3, V_ex_pt.shape[1] - 1)\n",
    "        correct_test = harmless_acts_normed_clusters_ex_pt[j_test] - harmful_acts_normed_clusters_ex_pt[i_test]\n",
    "        computed_test = V_ex_pt[i_test, j_test]\n",
    "        print(\"Sanity Check:\", np.isclose(correct_test, computed_test).all())\n",
    "\n",
    "        # Plan-weighted sum of displacements, V_bar = sum_ij P*[i, j] * V[i, j]\n",
    "        # (a weighted mean if P* sums to 1 -- presumably the case; verify).\n",
    "        scaled_V_ex_pt = V_ex_pt * P_star_ex_pt[:, :, None]\n",
    "        V_bar_ex_pt = np.sum(scaled_V_ex_pt, axis = (0, 1))\n",
    "\n",
    "        # Centered V\n",
    "        centered_V_ex_pt = V_ex_pt - V_bar_ex_pt[None, None, :]\n",
    "\n",
    "        # Weighted covariance: sum_ij P*[i, j] * c_ij c_ij^T, computed on the\n",
    "        # flattened (i * j, d) view of the centered displacements.\n",
    "        centered_V_reshaped_ex_pt = centered_V_ex_pt.reshape(-1, V_bar_ex_pt.shape[-1])\n",
    "        P_star_reshaped_ex_pt = P_star_ex_pt.reshape(-1, 1)\n",
    "        print(centered_V_reshaped_ex_pt.shape, P_star_reshaped_ex_pt.shape)\n",
    "        sigma_total_ex_pt = centered_V_reshaped_ex_pt.transpose(1, 0) @ (P_star_reshaped_ex_pt * centered_V_reshaped_ex_pt)\n",
    "\n",
    "        # Eigendecomposition. The result is unpacked as a plain tuple so this\n",
    "        # also works on NumPy versions where eigh does not return a named tuple.\n",
    "        ex_pt_eigdecomp_ex_pt = np.linalg.eigh(sigma_total_ex_pt)\n",
    "        eigvals_ascending, eigvecs_ascending = ex_pt_eigdecomp_ex_pt\n",
    "\n",
    "        # eigh returns real eigenvalues in ascending order; flip to descending\n",
    "        # and take the cumulative share of the total for the first max_pcs PCs.\n",
    "        eigvals_ex_pt = eigvals_ascending[::-1]\n",
    "        total_var_ex_pt = eigvals_ex_pt.sum()\n",
    "        eigvals_cumsum_ex_pt = np.cumsum(eigvals_ex_pt)\n",
    "        explained_variance_ex_pt = eigvals_cumsum_ex_pt[:max_pcs] / total_var_ex_pt\n",
    "\n",
    "        # Principal components as rows, strongest first: shape (n_pcs, d)\n",
    "        eigenvecs_ex_pt = eigvecs_ascending[:, ::-1]\n",
    "        top_k_pc_ex_pt = eigenvecs_ex_pt[:, :max_pcs].transpose()\n",
    "\n",
    "        # Plotting. The x axis counts PCs starting at 1 and follows the actual\n",
    "        # number of points, which can be below max_pcs for a small hidden size.\n",
    "        n_pcs_plotted = len(explained_variance_ex_pt)\n",
    "        fig = go.Figure()\n",
    "\n",
    "        fig.add_trace(\n",
    "            go.Scatter(\n",
    "                x = list(range(1, n_pcs_plotted + 1)),\n",
    "                y = explained_variance_ex_pt,\n",
    "                mode=\"lines+markers\",\n",
    "                yaxis=\"y\",\n",
    "            )\n",
    "        )\n",
    "\n",
    "        fig.update_layout(\n",
    "\n",
    "            plot_bgcolor=\"white\",\n",
    "            title_text = f\"{MODEL_NAME}: Explained Variance\",\n",
    "            grid=dict(rows=1, columns=1),\n",
    "            xaxis=dict(\n",
    "                type=\"category\",\n",
    "                dtick=4,\n",
    "                title=dict(text=\"No of PCs\", font=dict(size=20)),\n",
    "                gridcolor=\"lightgrey\",\n",
    "                tickfont=dict(size=18),\n",
    "            ),\n",
    "            yaxis=dict(\n",
    "                title=dict(text=\"Explained_Variance\", font=dict(size=20)),\n",
    "                gridcolor=\"lightgrey\",\n",
    "                zeroline=False,\n",
    "                tickfont=dict(size=18),\n",
    "            ),\n",
    "            yaxis_range = [0, 1],\n",
    "            hovermode=\"x unified\",\n",
    "            height=300,\n",
    "            width=600,\n",
    "        )\n",
    "        fig.show()\n",
    "        fig.write_image(VISUALIZATION_DIR / f\"{gemma_save_file_append}explained_variance_plot.pdf\", scale=5)\n",
    "\n",
    "        # PC scores: projection of every centered displacement on every kept PC\n",
    "        pc_scores_ex_pt = einops.einsum(top_k_pc_ex_pt, centered_V_ex_pt, \"k d, i j d -> k i j\")\n",
    "        # Spot-check one score. Dedicated names keep the loop variable `k` intact,\n",
    "        # and the probe indices are clamped to the available shape.\n",
    "        pc_chk = min(4, pc_scores_ex_pt.shape[0] - 1)\n",
    "        i_chk = min(2, pc_scores_ex_pt.shape[1] - 1)\n",
    "        j_chk = min(3, pc_scores_ex_pt.shape[2] - 1)\n",
    "        correct_val_ex_pt = np.dot(top_k_pc_ex_pt[pc_chk], centered_V_ex_pt[i_chk, j_chk])\n",
    "        computed_val_ex_pt = pc_scores_ex_pt[pc_chk, i_chk, j_chk]\n",
    "        print(\"Check 1:\", np.isclose(correct_val_ex_pt, computed_val_ex_pt))\n",
    "\n",
    "        # Payload stored for every steered module; each module gets its own\n",
    "        # shallow copy of this dict (the arrays themselves are shared).\n",
    "        config_entry = {\n",
    "            \"mode\": \"nonparametric_steering\",\n",
    "            \"source_acts_normed_clusters\": harmful_acts_normed_clusters_ex_pt,\n",
    "            \"transport_plan\": P_star_ex_pt,\n",
    "            \"v_bar\": V_bar_ex_pt,\n",
    "            \"pc_scores\": pc_scores_ex_pt,\n",
    "            \"top_K_pc\": top_k_pc_ex_pt,\n",
    "        }\n",
    "\n",
    "        # Create and save config (base): steer only the module of this extraction point.\n",
    "        # np.save pickles the dict inside a 0-d object array, so it has to be\n",
    "        # read back with np.load(..., allow_pickle=True).item().\n",
    "        module_name_base = _steering_module_name(ex_pt)\n",
    "        steering_config = {module_name_base: dict(config_entry)}\n",
    "\n",
    "        output_name_base = OUTPUT_DIR_REG / f\"steering_config-en-{ex_pt}.npy\"\n",
    "        np.save(output_name_base, steering_config)\n",
    "\n",
    "        # Config for the \"dirablate_layers\" output: the same payload is attached\n",
    "        # to the module of every extraction point in range(total_ex_pt).\n",
    "        steering_config_da = {\n",
    "            _steering_module_name(point): dict(config_entry)\n",
    "            for point in range(total_ex_pt)\n",
    "        }\n",
    "\n",
    "        output_name_da = OUTPUT_DIR_DA / f\"steering_config-en-{ex_pt}.npy\"\n",
    "        np.save(output_name_da, steering_config_da)\n",
    "\n",
    "    # Elbow plot of inertia vs. k (disabled).\n",
    "    # fig = plt.figure()\n",
    "    # plt.plot(k_grid, inertia_harmful, label=\"Harmful\")\n",
    "    # plt.plot(k_grid, inertia_harmless, label=\"Harmless\")\n",
    "    # plt.xlabel(\"Number of Clusters (k)\")\n",
    "    # plt.ylabel(\"Inertia\")\n",
    "    # plt.title(f\"Inertia vs Number of Clusters - {MODEL_NAME}\")\n",
    "    # if GEMMA_RELOCATE_MODE:\n",
    "    #     elbow_plot_save_path = VISUALIZATION_DIR / f\"elbow_plot_layer_r_{extraction_point}.pdf\"\n",
    "    # else:\n",
    "    #     elbow_plot_save_path = VISUALIZATION_DIR / f\"elbow_plot_layer_{extraction_point}.pdf\"\n",
    "    # plt.legend()\n",
    "    # fig.savefig(elbow_plot_save_path)\n",
    "    # if ex_pt == extraction_point[-1]:\n",
    "    #     plt.show()\n",
    "\n",
    "steering_config_da"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "lda_steering (3.11.14)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "3877270cf4bc42a9b6142cce7a5d8c54": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_ec8f6f360a2243b0ac98d34e825ba378",
      "max": 2,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_f2ee188bfaa84e9680dbc296b1adbef6",
      "value": 2
     }
    },
    "89797f6e82104058af92e3ceb094af66": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "89ee88168c474e9fbcf4a17f1483eff4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_c362d50107dd4a2db0d1a79da2af8d57",
      "placeholder": "​",
      "style": "IPY_MODEL_ffa85c694b694425999187b346c7ecfe",
      "value": "Loading checkpoint shards: 100%"
     }
    },
    "9a5611a341ed4673aaaf2f463f685d7c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_e973493cd6d14381bb4ad2f82417e8a9",
      "placeholder": "​",
      "style": "IPY_MODEL_89797f6e82104058af92e3ceb094af66",
      "value": " 2/2 [00:18&lt;00:00,  8.85s/it]"
     }
    },
    "a2de63dfbd6c485e841c6fcd1fefe451": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "ad063e2c68a44f009bfab68c141c09be": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_89ee88168c474e9fbcf4a17f1483eff4",
       "IPY_MODEL_3877270cf4bc42a9b6142cce7a5d8c54",
       "IPY_MODEL_9a5611a341ed4673aaaf2f463f685d7c"
      ],
      "layout": "IPY_MODEL_a2de63dfbd6c485e841c6fcd1fefe451"
     }
    },
    "c362d50107dd4a2db0d1a79da2af8d57": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "e973493cd6d14381bb4ad2f82417e8a9": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "ec8f6f360a2243b0ac98d34e825ba378": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f2ee188bfaa84e9680dbc296b1adbef6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "ffa85c694b694425999187b346c7ecfe": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
